Compare commits

...

24 Commits

Author SHA1 Message Date
dependabot[bot]
8143834258 Bump github.com/pion/dtls/v3 from 3.0.9 to 3.0.11
Bumps [github.com/pion/dtls/v3](https://github.com/pion/dtls) from 3.0.9 to 3.0.11.
- [Release notes](https://github.com/pion/dtls/releases)
- [Commits](https://github.com/pion/dtls/compare/v3.0.9...v3.0.11)

---
updated-dependencies:
- dependency-name: github.com/pion/dtls/v3
  dependency-version: 3.0.11
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-13 14:20:47 +00:00
Bethuel Mmbaga
d3eeb6d8ee [misc] Add cloud api spec to public open api with rest client (#5222) 2026-02-13 15:08:47 +03:00
Bethuel Mmbaga
7ebf37ef20 [management] Enforce access control on accessible peers (#5301) 2026-02-13 12:46:43 +03:00
Misha Bragin
64b849c801 [self-hosted] add netbird server (#5232)
* Unified NetBird combined server (Management, Signal, Relay, STUN) as a single executable with richer YAML configuration, validation, and defaults.
  * Official Dockerfile/image for single-container deployment.
  * Optional in-process profiling endpoint for diagnostics.
  * Multiplexing to route HTTP/gRPC/WebSocket traffic via one port; runtime hooks to inject custom handlers.
* **Chores**
  * Updated deployment scripts, compose files, and reverse-proxy templates to target the combined server; added example configs and getting-started updates.
2026-02-12 19:24:43 +01:00
Maycon Santos
69d4b5d821 [misc] Update sign pipeline version (#5296) 2026-02-12 11:31:49 +01:00
Viktor Liu
3dfa97dcbd [client] Fix stale entries in nftables with no handle (#5272) 2026-02-12 09:15:57 +01:00
Viktor Liu
1ddc9ce2bf [client] Fix nil pointer panic in device and engine code (#5287) 2026-02-12 09:15:42 +01:00
Maycon Santos
2de1949018 [client] Check if login is required on foreground mode (#5295) 2026-02-11 21:42:36 +01:00
Vlad
fc88399c23 [management] fixed ischild check (#5279) 2026-02-10 20:31:15 +03:00
Zoltan Papp
6981fdce7e [client] Fix race condition and ensure correct message ordering in Relay (#5265)
* Fix race condition and ensure correct message ordering in
connection establishment

Reorder operations in OpenConn to register the connection before
waiting for peer availability. This ensures:

- Connection is ready to receive messages before peer subscription
completes
- Transport messages and onconnected events maintain proper ordering
- No messages are lost during the connection establishment window
- Concurrent OpenConn calls cannot create duplicate connections

If peer availability check fails, the pre-registered connection is
properly cleaned up.

* Handle service shutdown during relay connection initialization

Ensure relay connections are properly cleaned up when the service is not running by verifying `serviceIsRunning` and removing stale entries from `c.conns` to prevent unintended behaviors.
2026-02-09 11:34:24 +01:00
Viktor Liu
08403f64aa [client] Add env var to skip DNS probing (#5270) 2026-02-09 11:09:11 +01:00
Viktor Liu
391221a986 [client] Fix uspfilter duplicate firewall rules (#5269) 2026-02-09 10:14:02 +01:00
Zoltan Papp
7bc85107eb Adds timing measurement to handleSync to help diagnose sync performance issues (#5228) 2026-02-06 19:50:48 +01:00
Zoltan Papp
3be16d19a0 [management] Feature/grpc debounce msgtype (#5239)
* Add gRPC update debouncing mechanism

Implements backpressure handling for peer network map updates to
efficiently handle rapid changes. First update is sent immediately,
subsequent rapid updates are coalesced, ensuring only the latest
update is sent after a 1-second quiet period.

* Enhance unit test to verify peer count synchronization with debouncing and timeout handling

* Debounce based on type

* Refactor test to validate timer restart after pending update dispatch

* Simplify timer reset for Go 1.23+ automatic channel draining

Remove manual channel drain in resetTimer() since Go 1.23+ automatically
drains the timer channel when Stop() returns false, making the
select-case pattern unnecessary.
2026-02-06 19:47:38 +01:00
Vlad
af8f730bda [management] check stream start time for connecting peer (#5267) 2026-02-06 18:00:43 +01:00
eyJhb
c3f176f348 [client] Fix wrong URL being logged for DefaultAdminURL (#5252)
- DefaultManagementURL was being logged instead of DefaultAdminURL
2026-02-06 11:23:36 +01:00
Viktor Liu
0119f3e9f4 [client] Fix netstack detection and add wireguard port option (#5251)
- Add WireguardPort option to embed.Options for custom port configuration
- Fix KernelInterface detection to account for netstack mode
- Skip SSH config updates when running in netstack mode
- Skip interface removal wait when running in netstack mode
- Use BindListener for netstack to avoid port conflicts on same host
2026-02-06 10:03:01 +01:00
Viktor Liu
1b96648d4d [client] Always log dns forwader responses (#5262) 2026-02-05 14:34:35 +01:00
Zoltan Papp
d2f9653cea Fix nil pointer panic in ICE agent during sleep/wake cycles (#5261)
Add defensive nil checks in ThreadSafeAgent.Close() to prevent panic
when agent field is nil. This can occur during Windows suspend/resume
when network interfaces are disrupted or the pion/ice library returns
nil without error.

Also capture agent pointer in local variable before goroutine execution
to prevent race conditions.

Fixes service crashes on laptop wake-up.
2026-02-05 12:06:28 +01:00
Zoltan Papp
194a986926 Cache the result of wgInterface.ToInterface() using sync.Once (#5256)
Avoid repeated conversions during route setup. The toInterface helper ensures
the conversion happens only once regardless of how many routes are added
or removed.
2026-02-04 22:22:37 +01:00
Viktor Liu
f7732557fa [client] Add missing bsd flags in debug bundle (#5254) 2026-02-04 18:07:27 +01:00
Vlad
d488f58311 [management] fix set disconnected status for connected peer (#5247) 2026-02-04 11:44:46 +01:00
Pascal Fischer
6fdc00ff41 [management] adding account id validation to accessible peers handler (#5246) 2026-02-03 17:30:02 +01:00
Misha Bragin
b20d484972 [docs] Add selfhosting video (#5235) 2026-02-01 16:06:36 +01:00
100 changed files with 11992 additions and 1124 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.0"
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -106,6 +106,26 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
@@ -520,6 +540,55 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -598,6 +667,18 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -675,6 +756,19 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
brews:
- ids:
- default

View File

@@ -60,8 +60,8 @@
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Self-Host NetBird (Video)
[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ)
### Key features

View File

@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
}
defer authClient.Close()
needsLogin := false
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
needsLogin, err := authClient.IsLoginRequired(ctx)
if err != nil {
return fmt.Errorf("check login required: %v", err)
}
jwtToken := ""

View File

@@ -71,6 +71,8 @@ type Options struct {
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
}
// validateCredentials checks that exactly one credential type is provided
@@ -140,6 +142,7 @@ func New(opts Options) (*Client, error) {
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
WireguardPort: opts.WireguardPort,
}
if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input)

View File

@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(nftRule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
}
if err := r.conn.Flush(); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *router) rollbackRules(pair firewall.RouterPair) {
keys := []string{
firewall.GenKey(firewall.ForwardingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
rule, exists := r.rules[ruleKey]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
// TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nil
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else {
rule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
// (e.g. from failed flushes) and updates handles for all existing rules.
func (r *router) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[string]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("list rules: %w", err)
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
newRules[string(rule.UserData)] = rule
}
}
}
return nil
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
}
if merr == nil {
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/id"
)
const (
@@ -719,3 +720,137 @@ func deleteWorkTable() {
}
}
}
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
err = r.refreshRulesMap()
require.NoError(t, err)
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
realRule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "real rule should still exist after refresh")
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
pair := firewall.RouterPair{
ID: "staletest",
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true,
}
rtr := manager.router
// First add succeeds
err = rtr.AddNatRule(pair)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(pair))
})
// Corrupt the handle to simulate stale state
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := rtr.rules[natRuleKey]; exists {
rule.Handle = 0
}
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
if rule, exists := rtr.rules[inverseKey]; exists {
rule.Handle = 0
}
// Adding the same rule again should succeed despite stale handles
err = rtr.AddNatRule(pair)
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
// Verify rules exist in kernel
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err)
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found++
}
}
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
}

View File

@@ -3,12 +3,6 @@
package uspfilter
import (
"context"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.resetState()
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)

View File

@@ -1,12 +1,9 @@
package uspfilter
import (
"context"
"fmt"
"net/netip"
"os/exec"
"syscall"
"time"
log "github.com/sirupsen/logrus"
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
m.resetState()
if !isWindowsFirewallReachable() {
return nil

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"context"
"encoding/binary"
"errors"
"fmt"
@@ -12,11 +13,13 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -24,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -89,6 +93,7 @@ type Manager struct {
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
ruleID := uuid.New().String()
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID,
id: string(ruleKey),
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
m.routeRules = append(m.routeRules, &rule)
m.routeRules.Sort()
m.routeRulesMap[ruleKey] = &rule
return &rule, nil
}
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteRouteRule(rule)
}
ruleID := rule.ID()
ruleKey := nbid.RuleID(rule.ID())
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == ruleID
return r.id == string(ruleKey)
})
if idx < 0 {
return fmt.Errorf("route rule not found: %s", ruleID)
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
}
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
delete(m.routeRulesMap, ruleKey)
return nil
}
@@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
maps.Clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules)
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
if m.udpTracker != nil {
m.udpTracker.Close()
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
}
}
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
if m.nativeFirewall == nil {

View File

@@ -0,0 +1,376 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
// filtering rule twice returns the same rule ID (idempotent behavior).
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{
netip.MustParsePrefix("100.64.1.0/24"),
netip.MustParsePrefix("100.64.2.0/24"),
}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule2)
// These should be the same (idempotent) like nftables/iptables implementations
assert.Equal(t, rule1.ID(), rule2.ID(),
"Adding the same rule twice should return the same rule ID (idempotent)")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule)")
}
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
// different parameters get distinct IDs.
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
assert.NotEqual(t, rule1.ID(), rule2.ID(),
"Different rules should have different IDs")
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
}
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
// rule during a network map update does not disrupt existing traffic.
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
}
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.True(t, passAfter,
"Traffic should still pass after rule update - no gap should occur")
}
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
// exactly one drop rule for the WireGuard network prefix, and calling it again
// returns the same rule without duplicating.
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call blockInvalidRouted directly multiple times
rule1, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule1)
rule2, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule2)
rule3, err := manager.blockInvalidRouted(ifaceMock)
require.NoError(t, err)
require.NotNil(t, rule3)
// All should return the same rule
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
// Should have exactly 1 route rule
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
// Verify the rule blocks traffic to the WG network
srcIP := netip.MustParseAddr("10.0.0.1")
dstIP := netip.MustParseAddr("100.64.0.50")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
}
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
// EnableRouting multiple times (as happens on each route update) does not
// accumulate duplicate block rules in the routeRules slice.
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
// Call EnableRouting multiple times (simulating repeated route updates)
for i := 0; i < 5; i++ {
require.NoError(t, manager.EnableRouting())
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 1, ruleCount,
"Repeated EnableRouting should not accumulate block rules")
}
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
// rule multiple times does not create duplicate entries.
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
}
manager.mutex.RLock()
ruleCount := len(manager.routeRules)
manager.mutex.RUnlock()
assert.Equal(t, 2, ruleCount,
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
}
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
// after adding it multiple times works correctly.
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
manager := setupTestManager(t)
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
[]byte("policy-1"),
sources,
destination,
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
require.NoError(t, err)
// Verify traffic no longer passes
srcIP := netip.MustParseAddr("100.64.1.5")
dstIP := netip.MustParseAddr("192.168.1.10")
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
assert.False(t, pass, "Traffic should not pass after rule deletion")
}
func setupTestManager(t *testing.T) *Manager {
t.Helper()
ctrl := gomock.NewController(t)
dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
wgNet := netip.MustParsePrefix("100.64.0.1/16")
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: wgNet.Addr(),
Network: wgNet,
}
},
GetDeviceFunc: func() *device.FilteredDevice {
return &device.FilteredDevice{Device: dev}
},
GetWGDeviceFunc: func() *wgdevice.Device {
return &wgdevice.Device{}
},
}
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.EnableRouting())
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
return manager
}

View File

@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
}
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
// to the deny map and can be cleanly deleted without leaving orphans.
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
// peer rules (simulating network map updates) does not leak rules in the maps.
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
addr := netip.MustParseAddr("192.168.1.1")
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
}
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
}
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
// IP are stored in separate maps and don't interfere with each other.
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, m.Close(nil))
}()
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
require.False(t, allowExists, "Allow rules should be empty")
}
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },

View File

@@ -29,8 +29,9 @@ type PacketFilter interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
filter PacketFilter
mutex sync.RWMutex
closeOnce sync.Once
}
// newDeviceFilter constructor function
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
}
}
// Close closes the underlying tun device exactly once.
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
// and multiple code paths can trigger Close on the same device.
func (d *FilteredDevice) Close() error {
var err error
d.closeOnce.Do(func() {
err = d.Device.Close()
})
if err != nil {
return err
}
return nil
}
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {

View File

@@ -82,7 +82,9 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
if cErr := tunIface.Close(); cErr != nil {
log.Debugf("failed to close tun device: %v", cErr)
}
return nil, fmt.Errorf("error configuring interface: %s", err)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
@@ -228,6 +229,10 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}
if nbnetstack.IsEnabled() {
return errors.FormatErrorOrNil(result)
}
if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
if err := w.Destroy(); err != nil {

View File

@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
}
}()
return nsTunDev, tunNet, nil
return t.tundev, tunNet, nil
}
func (t *NetStackTun) Close() error {

View File

@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
})
}
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
// This tests the full ACL manager -> uspfilter integration.
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// Apply the same rules 5 times (simulating repeated network map updates)
for i := 0; i < 5; i++ {
acl.ApplyFiltering(networkMap, false)
}
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
assert.Equal(t, 3, len(acl.peerRulesPairs),
"Should have exactly 3 rule pairs after 5 identical updates")
}
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
// up when they're removed from the network map in a subsequent update.
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: add deny and accept rules
networkMap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap1, false)
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
// Second update: remove the deny rule, keep only accept
networkMap2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap2, false)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Should have 1 rule after removing deny rule")
// Third update: remove all rules
networkMap3 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{},
FirewallRulesIsEmpty: true,
}
acl.ApplyFiltering(networkMap3, false)
assert.Equal(t, 0, len(acl.peerRulesPairs),
"Should have 0 rules after removing all rules")
}
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
// accept to deny (or vice versa), the old rule is properly removed and the new
// one added without leaking.
func TestRuleUpdateChangingAction(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
// First update: accept rule
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Second update: change to deny (same IP/port/proto, different action)
networkMap.FirewallRules = []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
}
acl.ApplyFiltering(networkMap, false)
// Should still have exactly 1 rule (the old accept removed, new deny added)
assert.Equal(t, 1, len(acl.peerRulesPairs),
"Changing action should result in exactly 1 rule, not 2")
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: device.WireGuardModuleIsLoaded(),
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
c.statusRecorder.UpdateLocalPeerState(localPeerState)

View File

@@ -6,7 +6,9 @@ import (
"fmt"
"net/netip"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"sync"
@@ -27,6 +29,8 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
)
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
type ReadyListener interface {
OnReady()
@@ -439,6 +443,17 @@ func (s *DefaultServer) SearchDomains() []string {
// ProbeAvailability tests each upstream group's servers for availability
// and deactivates the group if no server responds
func (s *DefaultServer) ProbeAvailability() {
if val := os.Getenv(envSkipDNSProbe); val != "" {
skipProbe, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
}
if skipProbe {
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
return
}
}
var wg sync.WaitGroup
for _, mux := range s.dnsMuxMap {
wg.Add(1)

View File

@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
return nberrors.FormatErrorOrNil(result)
}
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
if len(query.Question) == 0 {
return nil
return
}
question := query.Question[0]
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
qname := strings.ToLower(question.Name)
domain := strings.ToLower(question.Name)
logger.Tracef("question: domain=%s type=%s class=%s",
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
}
return nil
f.writeResponse(logger, w, resp, qname, startTime)
return
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
return nil
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
return
}
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
f.cache.set(domain, question.Qtype, result.IPs)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
f.cache.set(qname, question.Qtype, result.IPs)
return resp
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
}
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
type udpResponseWriter struct {
dns.ResponseWriter
query *dns.Msg
}
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
opt := u.query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
maxSize = int(opt.UDPSize())
}
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
return u.ResponseWriter.WriteMsg(resp)
}
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
opt := query.IsEdns0()
maxSize := dns.MinMsgSize
if opt != nil {
// client advertised a larger EDNS0 buffer
maxSize = int(opt.UDPSize())
}
// if our response is too big, truncate and set the TC bit
if resp.Len() > maxSize {
resp.Truncate(maxSize)
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
}
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
"dns_id": fmt.Sprintf("%04x", query.Id),
})
resp := f.handleDNSQuery(logger, w, query)
if resp == nil {
return
}
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)
return
}
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
f.handleDNSQuery(logger, w, query, startTime)
}
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg,
domain string,
result resutil.LookupResult,
startTime time.Time,
) {
qType := question.Qtype
qTypeName := dns.TypeToString[qType]
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
// NotFound: cache negative result and respond
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
f.cache.set(domain, question.Qtype, nil)
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.writeResponse(logger, w, resp, domain, startTime)
return
}
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
resp.Rcode = dns.RcodeSuccess
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write cached DNS response: %v", writeErr)
}
f.writeResponse(logger, w, resp, domain, startTime)
return
}
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
resp.Rcode = verifyResult.Rcode
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.writeResponse(logger, w, resp, domain, startTime)
return
}
}
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
// No cache or verification failed. Log with or without the server field for more context.
var dnsErr *net.DNSError
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
} else {
logger.Warnf(errResolveFailed, domain, result.Err)
}
// Write final failure response.
if writeErr := w.WriteMsg(resp); writeErr != nil {
logger.Errorf("failed to write failure DNS response: %v", writeErr)
}
f.writeResponse(logger, w, resp, domain, startTime)
}
// getMatchingEntries retrieves the resource IDs for a given domain.

View File

@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
require.NotNil(t, resp, "Expected response")
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
// Verify response
resp := mockWriter.GetLastResponse()
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
} else {
require.NotNil(t, resp, "Expected response")
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
// Verify response contains all IPs
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
},
}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Second query: serve from cache after upstream failure
q2 := &dns.Msg{}
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
require.NotNil(t, writtenResp, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2, "expected response to be written")
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q1 := &dns.Msg{}
q1.SetQuestion(mixedQuery+".", dns.TypeA)
w1 := &test.MockResponseWriter{}
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
resp1 := w1.GetLastResponse()
require.NotNil(t, resp1)
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
require.Len(t, resp1.Answer, 1)
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
q2 := &dns.Msg{}
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
var writtenResp *dns.Msg
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
w2 := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
require.NotNil(t, writtenResp)
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
require.Len(t, writtenResp.Answer, 1)
resp2 := w2.GetLastResponse()
require.NotNil(t, resp2)
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
require.Len(t, resp2.Answer, 1)
mockResolver.AssertExpectations(t)
}
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
assert.Empty(t, resp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
}

View File

@@ -543,11 +543,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.shutdownWg.Add(1)
wgIfaceName := e.wgInterface.Name()
go func() {
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.triggerClientRestart()
} else if err != nil {
@@ -828,6 +829,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1017,7 +1022,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.wgInterface.Address().String()
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = device.WireGuardModuleIsLoaded()
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
state.FQDN = conf.GetFqdn()
e.statusRecorder.UpdateLocalPeerState(state)

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
if netstack.IsEnabled() {
return nil
}
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
if netstack.IsEnabled() {
return
}
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is only used on Windows and JS platforms:
// BindListener is used on Windows, JS, and netstack platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// BindListener bypasses this by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
return NewUDPListener(m.wgIface, peerCfg)
}

View File

@@ -2,6 +2,7 @@ package ice
import (
"context"
"fmt"
"sync"
"time"
@@ -32,24 +33,6 @@ type ThreadSafeAgent struct {
once sync.Once
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
done := make(chan error, 1)
go func() {
done <- a.Agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) {
iceKeepAlive := iceKeepAlive()
iceDisconnectedTimeout := iceDisconnectedTimeout()
@@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c
return nil, err
}
if agent == nil {
return nil, fmt.Errorf("ice.NewAgent returned nil agent without error")
}
return &ThreadSafeAgent{Agent: agent}, nil
}
func (a *ThreadSafeAgent) Close() error {
var err error
a.once.Do(func() {
// Defensive check to prevent nil pointer dereference
// This can happen during sleep/wake transitions or memory corruption scenarios
// github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?)
// [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c]
agent := a.Agent
if agent == nil {
log.Warnf("ICE agent is nil during close, skipping")
return
}
done := make(chan error, 1)
go func() {
done <- agent.Close()
}()
select {
case err = <-done:
case <-time.After(iceAgentCloseTimeout):
log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout)
err = nil
}
})
return err
}
func GenerateICECredentials() (string, string, error) {
ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
if err != nil {

View File

@@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
if w.agent != nil {
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
sessionID, err := NewICESessionID()

View File

@@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultManagementURL)
log.Infof("using default Admin URL %s", DefaultAdminURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err

View File

@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
}
func (m *DefaultManager) setupRefCounters(useNoop bool) {
var once sync.Once
var wgIface *net.Interface
toInterface := func() *net.Interface {
once.Do(func() {
wgIface = m.wgInterface.ToInterface()
})
return wgIface
}
m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
},
func(prefix netip.Prefix, _ struct{}) error {
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
},
)

View File

@@ -4,16 +4,17 @@ package systemops
import (
"strings"
"syscall"
"golang.org/x/sys/unix"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
if routeMessageFlags&unix.RTF_UP == 0 {
return true
}
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 {
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 {
return true
}
@@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&syscall.RTF_UP != 0 {
if flags&unix.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&syscall.RTF_GATEWAY != 0 {
if flags&unix.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&syscall.RTF_HOST != 0 {
if flags&unix.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&syscall.RTF_REJECT != 0 {
if flags&unix.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&syscall.RTF_DYNAMIC != 0 {
if flags&unix.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&syscall.RTF_MODIFIED != 0 {
if flags&unix.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&syscall.RTF_STATIC != 0 {
if flags&unix.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&syscall.RTF_LLINFO != 0 {
if flags&unix.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&syscall.RTF_LOCAL != 0 {
if flags&unix.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&syscall.RTF_BLACKHOLE != 0 {
if flags&unix.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
if flags&syscall.RTF_CLONING != 0 {
if flags&unix.RTF_CLONING != 0 {
flagStrs = append(flagStrs, "C")
}
if flags&syscall.RTF_WASCLONED != 0 {
if flags&unix.RTF_WASCLONED != 0 {
flagStrs = append(flagStrs, "W")
}
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

View File

@@ -4,17 +4,18 @@ package systemops
import (
"strings"
"syscall"
"golang.org/x/sys/unix"
)
// filterRoutesByFlags returns true if the route message should be ignored based on its flags.
func filterRoutesByFlags(routeMessageFlags int) bool {
if routeMessageFlags&syscall.RTF_UP == 0 {
if routeMessageFlags&unix.RTF_UP == 0 {
return true
}
// NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 {
// NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0
if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 {
return true
}
@@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool {
func formatBSDFlags(flags int) string {
var flagStrs []string
if flags&syscall.RTF_UP != 0 {
if flags&unix.RTF_UP != 0 {
flagStrs = append(flagStrs, "U")
}
if flags&syscall.RTF_GATEWAY != 0 {
if flags&unix.RTF_GATEWAY != 0 {
flagStrs = append(flagStrs, "G")
}
if flags&syscall.RTF_HOST != 0 {
if flags&unix.RTF_HOST != 0 {
flagStrs = append(flagStrs, "H")
}
if flags&syscall.RTF_REJECT != 0 {
if flags&unix.RTF_REJECT != 0 {
flagStrs = append(flagStrs, "R")
}
if flags&syscall.RTF_DYNAMIC != 0 {
if flags&unix.RTF_DYNAMIC != 0 {
flagStrs = append(flagStrs, "D")
}
if flags&syscall.RTF_MODIFIED != 0 {
if flags&unix.RTF_MODIFIED != 0 {
flagStrs = append(flagStrs, "M")
}
if flags&syscall.RTF_STATIC != 0 {
if flags&unix.RTF_STATIC != 0 {
flagStrs = append(flagStrs, "S")
}
if flags&syscall.RTF_LLINFO != 0 {
if flags&unix.RTF_LLINFO != 0 {
flagStrs = append(flagStrs, "L")
}
if flags&syscall.RTF_LOCAL != 0 {
if flags&unix.RTF_LOCAL != 0 {
flagStrs = append(flagStrs, "l")
}
if flags&syscall.RTF_BLACKHOLE != 0 {
if flags&unix.RTF_BLACKHOLE != 0 {
flagStrs = append(flagStrs, "B")
}
// Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0
if flags&unix.RTF_PROTO1 != 0 {
flagStrs = append(flagStrs, "1")
}
if flags&unix.RTF_PROTO2 != 0 {
flagStrs = append(flagStrs, "2")
}
if flags&unix.RTF_PROTO3 != 0 {
flagStrs = append(flagStrs, "3")
}
if len(flagStrs) == 0 {
return "-"

5
combined/Dockerfile Normal file
View File

@@ -0,0 +1,5 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
COPY netbird-server /go/bin/netbird-server

715
combined/cmd/config.go Normal file
View File

@@ -0,0 +1,715 @@
package cmd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"path"
"strings"
"time"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
// CombinedConfig is the root configuration for the combined server.
// The combined server is primarily a Management server with optional embedded
// Signal, Relay, and STUN services.
//
// Architecture:
// - Management: Always runs locally (this IS the management server)
// - Signal: Runs locally by default; disabled if server.signalUri is set
// - Relay: Runs locally by default; disabled if server.relays is set
// - STUN: Runs locally on port 3478 by default; disabled if server.stuns is set
//
// All user-facing settings are under "server". The relay/signal/management
// fields are internal and populated automatically from server settings.
type CombinedConfig struct {
Server ServerConfig `yaml:"server"`
// Internal configs - populated from Server settings, not user-configurable
Relay RelayConfig `yaml:"-"`
Signal SignalConfig `yaml:"-"`
Management ManagementConfig `yaml:"-"`
}
// ServerConfig contains server-wide settings
// In simplified mode, this contains all configuration
type ServerConfig struct {
ListenAddress string `yaml:"listenAddress"`
MetricsPort int `yaml:"metricsPort"`
HealthcheckAddress string `yaml:"healthcheckAddress"`
LogLevel string `yaml:"logLevel"`
LogFile string `yaml:"logFile"`
TLS TLSConfig `yaml:"tls"`
// Simplified config fields (used when relay/signal/management sections are omitted)
ExposedAddress string `yaml:"exposedAddress"` // Public address with protocol (e.g., "https://example.com:443")
StunPorts []int `yaml:"stunPorts"` // STUN ports (empty to disable local STUN)
AuthSecret string `yaml:"authSecret"` // Shared secret for relay authentication
DataDir string `yaml:"dataDir"` // Data directory for all services
// External service overrides (simplified mode)
// When these are set, the corresponding local service is NOT started
// and these values are used for client configuration instead
Stuns []HostConfig `yaml:"stuns"` // External STUN servers (disables local STUN)
Relays RelaysConfig `yaml:"relays"` // External relay servers (disables local relay)
SignalURI string `yaml:"signalUri"` // External signal server (disables local signal)
// Management settings (simplified mode)
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// TLSConfig contains TLS/HTTPS settings
type TLSConfig struct {
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
LetsEncrypt LetsEncryptConfig `yaml:"letsencrypt"`
}
// LetsEncryptConfig contains Let's Encrypt settings
type LetsEncryptConfig struct {
Enabled bool `yaml:"enabled"`
DataDir string `yaml:"dataDir"`
Domains []string `yaml:"domains"`
Email string `yaml:"email"`
AWSRoute53 bool `yaml:"awsRoute53"`
}
// RelayConfig contains relay service settings
type RelayConfig struct {
Enabled bool `yaml:"enabled"`
ExposedAddress string `yaml:"exposedAddress"`
AuthSecret string `yaml:"authSecret"`
LogLevel string `yaml:"logLevel"`
Stun StunConfig `yaml:"stun"`
}
// StunConfig contains embedded STUN service settings
type StunConfig struct {
Enabled bool `yaml:"enabled"`
Ports []int `yaml:"ports"`
LogLevel string `yaml:"logLevel"`
}
// SignalConfig contains signal service settings
type SignalConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
}
// ManagementConfig contains management service settings
type ManagementConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
DataDir string `yaml:"dataDir"`
DnsDomain string `yaml:"dnsDomain"`
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
DisableDefaultPolicy bool `yaml:"disableDefaultPolicy"`
Auth AuthConfig `yaml:"auth"`
Stuns []HostConfig `yaml:"stuns"`
Relays RelaysConfig `yaml:"relays"`
SignalURI string `yaml:"signalUri"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// AuthConfig contains authentication/identity provider settings
type AuthConfig struct {
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
}
// AuthStorageConfig contains auth storage settings
type AuthStorageConfig struct {
Type string `yaml:"type"`
File string `yaml:"file"`
}
// AuthOwnerConfig contains initial admin user settings
type AuthOwnerConfig struct {
Email string `yaml:"email"`
Password string `yaml:"password"`
}
// HostConfig represents a STUN/TURN/Signal host
type HostConfig struct {
URI string `yaml:"uri"`
Proto string `yaml:"proto,omitempty"` // udp, dtls, tcp, http, https - defaults based on URI scheme
Username string `yaml:"username,omitempty"`
Password string `yaml:"password,omitempty"`
}
// RelaysConfig contains external relay server settings for clients
type RelaysConfig struct {
Addresses []string `yaml:"addresses"`
CredentialsTTL string `yaml:"credentialsTTL"`
Secret string `yaml:"secret"`
}
// StoreConfig contains database settings
type StoreConfig struct {
Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
}
// ReverseProxyConfig contains reverse proxy settings
type ReverseProxyConfig struct {
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
}
// DefaultConfig returns a CombinedConfig with default values
func DefaultConfig() *CombinedConfig {
return &CombinedConfig{
Server: ServerConfig{
ListenAddress: ":443",
MetricsPort: 9090,
HealthcheckAddress: ":9000",
LogLevel: "info",
LogFile: "console",
StunPorts: []int{3478},
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Store: StoreConfig{
Engine: "sqlite",
},
},
Relay: RelayConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
Stun: StunConfig{
Enabled: false,
Ports: []int{3478},
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
},
Signal: SignalConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
Management: ManagementConfig{
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Relays: RelaysConfig{
CredentialsTTL: "12h",
},
Store: StoreConfig{
Engine: "sqlite",
},
},
}
}
// hasRequiredSettings returns true if the configuration has the required server settings
func (c *CombinedConfig) hasRequiredSettings() bool {
return c.Server.ExposedAddress != ""
}
// parseExposedAddress extracts protocol, host, and host:port from the exposed address
// Input format: "https://example.com:443" or "http://example.com:8080" or "example.com:443"
// Returns: protocol ("https" or "http"), hostname only, and host:port
func parseExposedAddress(exposedAddress string) (protocol, hostname, hostPort string) {
// Default to https if no protocol specified
protocol = "https"
hostPort = exposedAddress
// Check for protocol prefix
if strings.HasPrefix(exposedAddress, "https://") {
protocol = "https"
hostPort = strings.TrimPrefix(exposedAddress, "https://")
} else if strings.HasPrefix(exposedAddress, "http://") {
protocol = "http"
hostPort = strings.TrimPrefix(exposedAddress, "http://")
}
// Extract hostname (without port)
hostname = hostPort
if host, _, err := net.SplitHostPort(hostPort); err == nil {
hostname = host
}
return protocol, hostname, hostPort
}
// ApplySimplifiedDefaults populates internal relay/signal/management configs from server settings.
// Management is always enabled. Signal, Relay, and STUN are enabled unless external
// overrides are configured (server.signalUri, server.relays, server.stuns).
func (c *CombinedConfig) ApplySimplifiedDefaults() {
if !c.hasRequiredSettings() {
return
}
// Parse exposed address to extract protocol and hostname
exposedProto, exposedHost, exposedHostPort := parseExposedAddress(c.Server.ExposedAddress)
// Check for external service overrides
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
hasExternalSignal := c.Server.SignalURI != ""
hasExternalStuns := len(c.Server.Stuns) > 0
// Default stunPorts to [3478] if not specified and no external STUN
if len(c.Server.StunPorts) == 0 && !hasExternalStuns {
c.Server.StunPorts = []int{3478}
}
c.applyRelayDefaults(exposedProto, exposedHostPort, hasExternalRelay, hasExternalStuns)
c.applySignalDefaults(hasExternalSignal)
c.applyManagementDefaults(exposedHost)
// Auto-configure client settings (stuns, relays, signalUri)
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
}
// applyRelayDefaults configures the relay service if no external relay is configured.
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
if hasExternalRelay {
return
}
c.Relay.Enabled = true
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
c.Relay.ExposedAddress = fmt.Sprintf("%s://%s", relayProto, exposedHostPort)
c.Relay.AuthSecret = c.Server.AuthSecret
if c.Relay.LogLevel == "" {
c.Relay.LogLevel = c.Server.LogLevel
}
// Enable local STUN only if no external STUN servers and stunPorts are configured
if !hasExternalStuns && len(c.Server.StunPorts) > 0 {
c.Relay.Stun.Enabled = true
c.Relay.Stun.Ports = c.Server.StunPorts
if c.Relay.Stun.LogLevel == "" {
c.Relay.Stun.LogLevel = c.Server.LogLevel
}
}
}
// applySignalDefaults configures the signal service if no external signal is configured.
func (c *CombinedConfig) applySignalDefaults(hasExternalSignal bool) {
if hasExternalSignal {
return
}
c.Signal.Enabled = true
if c.Signal.LogLevel == "" {
c.Signal.LogLevel = c.Server.LogLevel
}
}
// applyManagementDefaults configures the management service (always enabled).
func (c *CombinedConfig) applyManagementDefaults(exposedHost string) {
c.Management.Enabled = true
if c.Management.LogLevel == "" {
c.Management.LogLevel = c.Server.LogLevel
}
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
c.Management.DataDir = c.Server.DataDir
}
c.Management.DnsDomain = exposedHost
c.Management.DisableAnonymousMetrics = c.Server.DisableAnonymousMetrics
c.Management.DisableGeoliteUpdate = c.Server.DisableGeoliteUpdate
// Copy auth config from server if management auth issuer is not set
if c.Management.Auth.Issuer == "" && c.Server.Auth.Issuer != "" {
c.Management.Auth = c.Server.Auth
}
// Copy store config from server if not set
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
if c.Server.Store.Engine != "" {
c.Management.Store = c.Server.Store
}
}
// Copy reverse proxy config from server
if len(c.Server.ReverseProxy.TrustedHTTPProxies) > 0 || c.Server.ReverseProxy.TrustedHTTPProxiesCount > 0 || len(c.Server.ReverseProxy.TrustedPeers) > 0 {
c.Management.ReverseProxy = c.Server.ReverseProxy
}
}
// autoConfigureClientSettings sets up STUN/relay/signal URIs for clients
// External overrides from server config take precedence over auto-generated values
func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort string, hasExternalStuns, hasExternalRelay, hasExternalSignal bool) {
// Determine relay protocol from exposed protocol
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
// Configure STUN servers for clients
if hasExternalStuns {
// Use external STUN servers from server config
c.Management.Stuns = c.Server.Stuns
} else if len(c.Server.StunPorts) > 0 && len(c.Management.Stuns) == 0 {
// Auto-configure local STUN servers for all ports
for _, port := range c.Server.StunPorts {
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
})
}
}
// Configure relay for clients
if hasExternalRelay {
// Use external relay config from server
c.Management.Relays = c.Server.Relays
} else if len(c.Management.Relays.Addresses) == 0 {
// Auto-configure local relay
c.Management.Relays.Addresses = []string{
fmt.Sprintf("%s://%s", relayProto, exposedHostPort),
}
}
if c.Management.Relays.Secret == "" {
c.Management.Relays.Secret = c.Server.AuthSecret
}
if c.Management.Relays.CredentialsTTL == "" {
c.Management.Relays.CredentialsTTL = "12h"
}
// Configure signal for clients
if hasExternalSignal {
// Use external signal URI from server config
c.Management.SignalURI = c.Server.SignalURI
} else if c.Management.SignalURI == "" {
// Auto-configure local signal
c.Management.SignalURI = fmt.Sprintf("%s://%s", exposedProto, exposedHostPort)
}
}
// LoadConfig loads configuration from a YAML file
func LoadConfig(configPath string) (*CombinedConfig, error) {
cfg := DefaultConfig()
if configPath == "" {
return cfg, nil
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Populate internal configs from server settings
cfg.ApplySimplifiedDefaults()
return cfg, nil
}
// Validate validates the configuration
func (c *CombinedConfig) Validate() error {
if c.Server.ExposedAddress == "" {
return fmt.Errorf("server.exposedAddress is required")
}
if c.Server.DataDir == "" {
return fmt.Errorf("server.dataDir is required")
}
// Validate STUN ports
seen := make(map[int]bool)
for _, port := range c.Server.StunPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid server.stunPorts value %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d in server.stunPorts", port)
}
seen[port] = true
}
// authSecret is required only if running local relay (no external relay configured)
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
if !hasExternalRelay && c.Server.AuthSecret == "" {
return fmt.Errorf("server.authSecret is required when running local relay")
}
return nil
}
// HasTLSCert returns true if TLS certificate files are configured
func (c *CombinedConfig) HasTLSCert() bool {
return c.Server.TLS.CertFile != "" && c.Server.TLS.KeyFile != ""
}
// HasLetsEncrypt returns true if Let's Encrypt is configured
func (c *CombinedConfig) HasLetsEncrypt() bool {
return c.Server.TLS.LetsEncrypt.Enabled &&
c.Server.TLS.LetsEncrypt.DataDir != "" &&
len(c.Server.TLS.LetsEncrypt.Domains) > 0
}
// parseExplicitProtocol parses an explicit protocol string to nbconfig.Protocol
func parseExplicitProtocol(proto string) (nbconfig.Protocol, bool) {
switch strings.ToLower(proto) {
case "udp":
return nbconfig.UDP, true
case "dtls":
return nbconfig.DTLS, true
case "tcp":
return nbconfig.TCP, true
case "http":
return nbconfig.HTTP, true
case "https":
return nbconfig.HTTPS, true
default:
return "", false
}
}
// parseStunProtocol determines protocol for STUN/TURN servers.
// stun: → UDP, stuns: → DTLS, turn: → UDP, turns: → DTLS
// Explicit proto overrides URI scheme. Defaults to UDP.
func parseStunProtocol(uri, proto string) nbconfig.Protocol {
if proto != "" {
if p, ok := parseExplicitProtocol(proto); ok {
return p
}
}
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "stuns:"):
return nbconfig.DTLS
case strings.HasPrefix(uri, "turns:"):
return nbconfig.DTLS
default:
// stun:, turn:, or no scheme - default to UDP
return nbconfig.UDP
}
}
// parseSignalProtocol determines protocol for Signal servers.
// https:// → HTTPS, http:// → HTTP. Defaults to HTTPS.
func parseSignalProtocol(uri string) nbconfig.Protocol {
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "http://"):
return nbconfig.HTTP
default:
// https:// or no scheme - default to HTTPS
return nbconfig.HTTPS
}
}
// stripSignalProtocol removes the protocol prefix from a signal URI.
// Returns just the host:port (e.g., "selfhosted2.demo.netbird.io:443").
func stripSignalProtocol(uri string) string {
uri = strings.TrimPrefix(uri, "https://")
uri = strings.TrimPrefix(uri, "http://")
return uri
}
// ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management
// Build STUN hosts
var stuns []*nbconfig.Host
for _, s := range mgmt.Stuns {
stuns = append(stuns, &nbconfig.Host{
URI: s.URI,
Proto: parseStunProtocol(s.URI, s.Proto),
Username: s.Username,
Password: s.Password,
})
}
// Build relay config
var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
var ttl time.Duration
if mgmt.Relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
}
}
// Build signal config
var signalConfig *nbconfig.Host
if mgmt.SignalURI != "" {
signalConfig = &nbconfig.Host{
URI: stripSignalProtocol(mgmt.SignalURI),
Proto: parseSignalProtocol(mgmt.SignalURI),
}
}
// Build store config
storeConfig := nbconfig.StoreConfig{
Engine: types.Engine(mgmt.Store.Engine),
}
// Build reverse proxy config
reverseProxy := nbconfig.ReverseProxy{
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
}
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedHTTPProxies = append(reverseProxy.TrustedHTTPProxies, prefix)
}
}
for _, p := range mgmt.ReverseProxy.TrustedPeers {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedPeers = append(reverseProxy.TrustedPeers, prefix)
}
}
// Build HTTP config (required, even if empty)
httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server)
storageFile := mgmt.Auth.Storage.File
if storageFile == "" {
storageFile = path.Join(mgmt.DataDir, "idp.db")
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
}
// Set HTTP config fields for embedded IDP
httpConfig.AuthIssuer = mgmt.Auth.Issuer
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
return &nbconfig.Config{
Stuns: stuns,
Relay: relayConfig,
Signal: signalConfig,
Datadir: mgmt.DataDir,
DataStoreEncryptionKey: mgmt.Store.EncryptionKey,
HttpConfig: httpConfig,
StoreConfig: storeConfig,
ReverseProxy: reverseProxy,
DisableDefaultPolicy: mgmt.DisableDefaultPolicy,
EmbeddedIdP: embeddedIdP,
}, nil
}
// ApplyEmbeddedIdPConfig applies embedded IdP configuration to the management config.
// This mirrors the logic in management/cmd/management.go ApplyEmbeddedIdPConfig.
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort int, disableSingleAccMode bool) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
// Embedded IdP requires single account mode
if disableSingleAccMode {
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP")
}
// Set LocalAddress for embedded IdP, used for internal JWT validation
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
return nil
}
// EnsureEncryptionKey generates an encryption key if not set.
// Unlike management server, we don't write back to the config file.
func EnsureEncryptionKey(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
key, err := crypt.GenerateKey()
if err != nil {
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
}
cfg.DataStoreEncryptionKey = key
keyPreview := key[:8] + "..."
log.WithContext(ctx).Warnf("DataStoreEncryptionKey generated (%s); add it to your config file under 'server.store.encryptionKey' to persist across restarts", keyPreview)
return nil
}
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
if cfg.Relay != nil {
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
}
}

33
combined/cmd/pprof.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

711
combined/cmd/root.go Normal file
View File

@@ -0,0 +1,711 @@
package cmd
import (
"context"
"crypto/sha256"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/coder/websocket"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel/metric"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/encryption"
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
)
var (
configPath string
config *CombinedConfig
rootCmd = &cobra.Command{
Use: "combined",
Short: "Combined Netbird server (Management + Signal + Relay + STUN)",
Long: `Combined Netbird server for self-hosted deployments.
All services (Management, Signal, Relay) are multiplexed on a single port.
Optional STUN server runs on separate UDP ports.
Configuration is loaded from a YAML file specified with --config.`,
SilenceUsage: true,
SilenceErrors: true,
RunE: execute,
}
)
func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
}
func Execute() error {
return rootCmd.Execute()
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
<-osSigs
}
func execute(cmd *cobra.Command, _ []string) error {
if err := initializeConfig(); err != nil {
return err
}
// Management is required as the base server when signal or relay are enabled
if (config.Signal.Enabled || config.Relay.Enabled) && !config.Management.Enabled {
return fmt.Errorf("management must be enabled when signal or relay are enabled (provides the base HTTP server)")
}
servers, err := createAllServers(cmd.Context(), config)
if err != nil {
return err
}
// Register services with management's gRPC server using AfterInit hook
setupServerHooks(servers, config)
// Start management server (this also starts the HTTP listener)
if servers.mgmtSrv != nil {
if err := servers.mgmtSrv.Start(cmd.Context()); err != nil {
cleanupSTUNListeners(servers.stunListeners)
return fmt.Errorf("failed to start management server: %w", err)
}
}
// Start all other servers
wg := sync.WaitGroup{}
startServers(&wg, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.metricsServer)
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = shutdownServers(ctx, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.mgmtSrv, servers.metricsServer)
wg.Wait()
return err
}
// initializeConfig loads and validates the configuration, then initializes logging.
func initializeConfig() error {
var err error
config, err = LoadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
if err := config.Validate(); err != nil {
return fmt.Errorf("invalid config: %w", err)
}
if err := util.InitLog(config.Server.LogLevel, config.Server.LogFile); err != nil {
return fmt.Errorf("failed to initialize log: %w", err)
}
if dsn := config.Server.Store.DSN; dsn != "" {
switch strings.ToLower(config.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
log.Infof("Starting combined NetBird server")
logConfig(config)
logEnvVars()
return nil
}
// serverInstances holds all server instances created during startup.
type serverInstances struct {
relaySrv *relayServer.Server
mgmtSrv *mgmtServer.BaseServer
signalSrv *signalServer.Server
healthcheck *healthcheck.Server
stunServer *stun.Server
stunListeners []*net.UDPConn
metricsServer *sharedMetrics.Metrics
}
// createAllServers creates all server instances based on configuration.
func createAllServers(ctx context.Context, cfg *CombinedConfig) (*serverInstances, error) {
metricsServer, err := sharedMetrics.NewServer(cfg.Server.MetricsPort, "")
if err != nil {
return nil, fmt.Errorf("failed to create metrics server: %w", err)
}
servers := &serverInstances{
metricsServer: metricsServer,
}
_, tlsSupport, err := handleTLSConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to setup TLS config: %w", err)
}
if err := servers.createRelayServer(cfg, tlsSupport); err != nil {
return nil, err
}
if err := servers.createManagementServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createSignalServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createHealthcheckServer(cfg); err != nil {
return nil, err
}
return servers, nil
}
func (s *serverInstances) createRelayServer(cfg *CombinedConfig, tlsSupport bool) error {
if !cfg.Relay.Enabled {
return nil
}
var err error
s.stunListeners, err = createSTUNListeners(cfg)
if err != nil {
return err
}
hashedSecret := sha256.Sum256([]byte(cfg.Relay.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
relayCfg := relayServer.Config{
Meter: s.metricsServer.Meter,
ExposedAddress: cfg.Relay.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
s.relaySrv, err = createRelayServer(relayCfg, s.stunListeners)
if err != nil {
return err
}
log.Infof("Relay server created")
if len(s.stunListeners) > 0 {
s.stunServer = stun.NewServer(s.stunListeners, cfg.Relay.Stun.LogLevel)
}
return nil
}
func (s *serverInstances) createManagementServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Management.Enabled {
return nil
}
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return fmt.Errorf("failed to create management config: %w", err)
}
_, portStr, portErr := net.SplitHostPort(cfg.Server.ListenAddress)
if portErr != nil {
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
if err := ApplyEmbeddedIdPConfig(ctx, mgmtConfig, mgmtPort, false); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to apply embedded IdP config: %w", err)
}
if err := EnsureEncryptionKey(ctx, mgmtConfig); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to ensure encryption key: %w", err)
}
LogConfigInfo(mgmtConfig)
s.mgmtSrv, err = createManagementServer(cfg, mgmtConfig)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management server: %w", err)
}
// Inject externally-managed AppMetrics so management uses the shared metrics server
appMetrics, err := telemetry.NewAppMetricsWithMeter(ctx, s.metricsServer.Meter)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management app metrics: %w", err)
}
mgmtServer.Inject[telemetry.AppMetrics](s.mgmtSrv, appMetrics)
log.Infof("Management server created")
return nil
}
func (s *serverInstances) createSignalServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Signal.Enabled {
return nil
}
var err error
s.signalSrv, err = signalServer.NewServer(ctx, s.metricsServer.Meter, "signal_")
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create signal server: %w", err)
}
log.Infof("Signal server created")
return nil
}
func (s *serverInstances) createHealthcheckServer(cfg *CombinedConfig) error {
hCfg := healthcheck.Config{
ListenAddress: cfg.Server.HealthcheckAddress,
ServiceChecker: s.relaySrv,
}
var err error
s.healthcheck, err = createHealthCheck(hCfg, s.stunListeners)
return err
}
// setupServerHooks registers services with management's gRPC server.
func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
if servers.mgmtSrv == nil {
return
}
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
})
}
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
if srv != nil {
instanceURL := srv.InstanceURL()
log.Infof("Relay server instance URL: %s", instanceURL.String())
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
var errs error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}
if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
}
if srv != nil {
if err := srv.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
}
}
if mgmtSrv != nil {
log.Infof("shutting down management and signal servers")
if err := mgmtSrv.Stop(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close management server: %w", err))
}
}
if metricsServer != nil {
log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
}
}
return errs
}
func createHealthCheck(hCfg healthcheck.Config, stunListeners []*net.UDPConn) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create healthcheck server: %w", err)
}
return httpHealthcheck, nil
}
func createRelayServer(cfg relayServer.Config, stunListeners []*net.UDPConn) (*relayServer.Server, error) {
srv, err := relayServer.NewServer(cfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create relay server: %w", err)
}
return srv, nil
}
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}
func createSTUNListeners(cfg *CombinedConfig) ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cfg.Relay.Stun.Enabled {
for _, port := range cfg.Relay.Stun.Ports {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %w", port, err)
}
stunListeners = append(stunListeners, listener)
log.Infof("STUN server listening on UDP port %d", port)
}
}
return stunListeners, nil
}
func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
tlsCfg := cfg.Server.TLS
if tlsCfg.LetsEncrypt.AWSRoute53 {
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
r53 := encryption.Route53TLS{
DataDir: tlsCfg.LetsEncrypt.DataDir,
Email: tlsCfg.LetsEncrypt.Email,
Domains: tlsCfg.LetsEncrypt.Domains,
}
tc, err := r53.GetCertificate()
if err != nil {
return nil, false, err
}
return tc, true, nil
}
if cfg.HasLetsEncrypt() {
log.Infof("setting up TLS with Let's Encrypt")
certManager, err := encryption.CreateCertManager(tlsCfg.LetsEncrypt.DataDir, tlsCfg.LetsEncrypt.Domains...)
if err != nil {
return nil, false, fmt.Errorf("failed creating LetsEncrypt cert manager: %w", err)
}
return certManager.TLSConfig(), true, nil
}
if cfg.HasTLSCert() {
log.Debugf("using file based TLS config")
tc, err := encryption.LoadTLSConfig(tlsCfg.CertFile, tlsCfg.KeyFile)
if err != nil {
return nil, false, err
}
return tc, true, nil
}
return nil, false, nil
}
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil {
// If no port specified, assume default
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
mgmtSrv := mgmtServer.NewServer(
mgmtConfig,
dnsDomain,
singleAccModeDomain,
mgmtPort,
cfg.Server.MetricsPort,
mgmt.DisableAnonymousMetrics,
mgmt.DisableGeoliteUpdate,
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
true,
)
return mgmtSrv, nil
}
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn net.Conn)
if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept()
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
// Native gRPC traffic (HTTP/2 with gRPC content-type)
case r.ProtoMajor == 2 && (strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto")):
grpcServer.ServeHTTP(w, r)
// WebSocket proxy for Management gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(w, r)
// WebSocket proxy for Signal gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
if cfg.Signal.Enabled {
wsProxy.Handler().ServeHTTP(w, r)
} else {
http.Error(w, "Signal service not enabled", http.StatusNotFound)
}
// Relay WebSocket
case r.URL.Path == "/relay":
if relayAcceptFn != nil {
handleRelayWebSocket(w, r, relayAcceptFn, cfg)
} else {
http.Error(w, "Relay service not enabled", http.StatusNotFound)
}
// Management HTTP API (default)
default:
httpHandler.ServeHTTP(w, r)
}
})
}
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
wsConn, err := websocket.Accept(w, r, acceptOptions)
if err != nil {
log.Errorf("failed to accept relay ws connection: %s", err)
return
}
connRemoteAddr := r.RemoteAddr
if r.Header.Get("X-Real-Ip") != "" && r.Header.Get("X-Real-Port") != "" {
connRemoteAddr = net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
}
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
log.Debugf("Relay WS client connected from: %s", rAddr)
conn := ws.NewConn(wsConn, lAddr, rAddr)
acceptFn(conn)
}
// logConfig prints all configuration parameters for debugging
func logConfig(cfg *CombinedConfig) {
log.Info("=== Configuration ===")
logServerConfig(cfg)
logComponentsConfig(cfg)
logRelayConfig(cfg)
logManagementConfig(cfg)
log.Info("=== End Configuration ===")
}
func logServerConfig(cfg *CombinedConfig) {
log.Info("--- Server ---")
log.Infof(" Listen address: %s", cfg.Server.ListenAddress)
log.Infof(" Exposed address: %s", cfg.Server.ExposedAddress)
log.Infof(" Healthcheck address: %s", cfg.Server.HealthcheckAddress)
log.Infof(" Metrics port: %d", cfg.Server.MetricsPort)
log.Infof(" Log level: %s", cfg.Server.LogLevel)
log.Infof(" Data dir: %s", cfg.Server.DataDir)
switch {
case cfg.HasTLSCert():
log.Infof(" TLS: cert=%s, key=%s", cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
case cfg.HasLetsEncrypt():
log.Infof(" TLS: Let's Encrypt (domains=%v)", cfg.Server.TLS.LetsEncrypt.Domains)
default:
log.Info(" TLS: disabled (using reverse proxy)")
}
}
func logComponentsConfig(cfg *CombinedConfig) {
log.Info("--- Components ---")
log.Infof(" Management: %v (log level: %s)", cfg.Management.Enabled, cfg.Management.LogLevel)
log.Infof(" Signal: %v (log level: %s)", cfg.Signal.Enabled, cfg.Signal.LogLevel)
log.Infof(" Relay: %v (log level: %s)", cfg.Relay.Enabled, cfg.Relay.LogLevel)
}
func logRelayConfig(cfg *CombinedConfig) {
if !cfg.Relay.Enabled {
return
}
log.Info("--- Relay ---")
log.Infof(" Exposed address: %s", cfg.Relay.ExposedAddress)
log.Infof(" Auth secret: %s...", maskSecret(cfg.Relay.AuthSecret))
if cfg.Relay.Stun.Enabled {
log.Infof(" STUN ports: %v (log level: %s)", cfg.Relay.Stun.Ports, cfg.Relay.Stun.LogLevel)
} else {
log.Info(" STUN: disabled")
}
}
func logManagementConfig(cfg *CombinedConfig) {
if !cfg.Management.Enabled {
return
}
log.Info("--- Management ---")
log.Infof(" Data dir: %s", cfg.Management.DataDir)
log.Infof(" DNS domain: %s", cfg.Management.DnsDomain)
log.Infof(" Store engine: %s", cfg.Management.Store.Engine)
if cfg.Server.Store.DSN != "" {
log.Infof(" Store DSN: %s", maskDSNPassword(cfg.Server.Store.DSN))
}
log.Info(" Auth (embedded IdP):")
log.Infof(" Issuer: %s", cfg.Management.Auth.Issuer)
log.Infof(" Dashboard redirect URIs: %v", cfg.Management.Auth.DashboardRedirectURIs)
log.Infof(" CLI redirect URIs: %v", cfg.Management.Auth.CLIRedirectURIs)
log.Info(" Client settings:")
log.Infof(" Signal URI: %s", cfg.Management.SignalURI)
for _, s := range cfg.Management.Stuns {
log.Infof(" STUN: %s", s.URI)
}
if len(cfg.Management.Relays.Addresses) > 0 {
log.Infof(" Relay addresses: %v", cfg.Management.Relays.Addresses)
log.Infof(" Relay credentials TTL: %s", cfg.Management.Relays.CredentialsTTL)
}
}
// logEnvVars logs all NB_ environment variables that are currently set
func logEnvVars() {
log.Info("=== Environment Variables ===")
found := false
for _, env := range os.Environ() {
if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key)
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
value = maskSecret(value)
}
log.Infof(" %s=%s", key, value)
found = true
}
}
if !found {
log.Info(" (none set)")
}
log.Info("=== End Environment Variables ===")
}
// maskDSNPassword masks the password in a DSN string.
// Handles both key=value format ("password=secret") and URI format ("user:secret@host").
func maskDSNPassword(dsn string) string {
// Key=value format: "host=localhost user=nb password=secret dbname=nb"
if strings.Contains(dsn, "password=") {
parts := strings.Fields(dsn)
for i, p := range parts {
if strings.HasPrefix(p, "password=") {
parts[i] = "password=****"
}
}
return strings.Join(parts, " ")
}
// URI format: "user:password@host..."
if atIdx := strings.Index(dsn, "@"); atIdx != -1 {
prefix := dsn[:atIdx]
if colonIdx := strings.Index(prefix, ":"); colonIdx != -1 {
return prefix[:colonIdx+1] + "****" + dsn[atIdx:]
}
}
return dsn
}
// maskSecret returns first 4 chars of secret followed by "..."
func maskSecret(secret string) string {
if len(secret) <= 4 {
return "****"
}
return secret[:4] + "..."
}

View File

@@ -0,0 +1,111 @@
# NetBird Combined Server Configuration
# Copy this file to config.yaml and customize for your deployment
#
# This is a Management server with optional embedded Signal, Relay, and STUN services.
# By default, all services run locally. You can use external services instead by
# setting the corresponding override fields.
#
# Architecture:
# - Management: Always runs locally (this IS the management server)
# - Signal: Local by default; set 'signalUri' to use external (disables local)
# - Relay: Local by default; set 'relays' to use external (disables local)
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Public address that peers will use to connect to this server
# Used for relay connections and management DNS domain
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
exposedAddress: "https://server.mycompany.com:443"
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
# stunPorts:
# - 3478
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Shared secret for relay authentication (required when running local relay)
authSecret: "your-secret-key-here"
# Data directory for all services
dataDir: "/var/lib/netbird/"
# ============================================================================
# External Service Overrides (optional)
# Use these to point to external Signal, Relay, or STUN servers instead of
# running them locally. When set, the corresponding local service is disabled.
# ============================================================================
# External STUN servers - disables local STUN server
# stuns:
# - uri: "stun:stun.example.com:3478"
# - uri: "stun:stun.example.com:3479"
# External relay servers - disables local relay server
# relays:
# addresses:
# - "rels://relay.example.com:443"
# credentialsTTL: "12h"
# secret: "relay-shared-secret"
# External signal server - disables local signal server
# signalUri: "https://signal.example.com:443"
# ============================================================================
# Management Settings
# ============================================================================
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
# Embedded authentication/identity provider (Dex) configuration (always enabled)
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://server.mycompany.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.netbird.io/nb-auth"
- "https://app.netbird.io/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0
# trustedPeers: []

View File

@@ -0,0 +1,115 @@
# Simplified Combined NetBird Server Configuration
# Copy this file to config.yaml and customize for your deployment
# Server-wide settings
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Relay service configuration
relay:
# Enable/disable the relay service
enabled: true
# Public address that peers will use to connect to this relay
# Format: hostname:port or ip:port
exposedAddress: "relay.example.com:443"
# Shared secret for relay authentication (required when enabled)
authSecret: "your-secret-key-here"
# Log level for relay (reserved for future use, currently uses global log level)
logLevel: "info"
# Embedded STUN server (optional)
stun:
enabled: false
ports: [3478]
logLevel: "info"
# Signal service configuration
signal:
# Enable/disable the signal service
enabled: true
# Log level for signal (reserved for future use, currently uses global log level)
logLevel: "info"
# Management service configuration
management:
# Enable/disable the management service
enabled: true
# Data directory for management service
dataDir: "/var/lib/netbird/"
# DNS domain for the management server
dnsDomain: ""
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://management.example.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.example.com/nb-auth"
- "https://app.example.com/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# External STUN servers (for client config)
stuns: []
# - uri: "stun:stun.example.com:3478"
# External relay servers (for client config)
relays:
addresses: []
# - "rels://relay.example.com:443"
credentialsTTL: "12h"
secret: ""
# External signal server URI (for client config)
signalUri: ""
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings
reverseProxy:
trustedHTTPProxies: []
trustedHTTPProxiesCount: 0
trustedPeers: []

13
combined/main.go Normal file
View File

@@ -0,0 +1,13 @@
package main
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/combined/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
log.Fatalf("failed to execute command: %v", err)
}
}

5
go.mod
View File

@@ -68,7 +68,7 @@ require (
github.com/mdlayher/socket v0.5.1
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0
@@ -243,9 +243,10 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/dtls/v3 v3.0.11 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pion/transport/v4 v4.0.1 // indirect
github.com/pion/turn/v4 v4.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect

10
go.sum
View File

@@ -406,8 +406,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f h1:CTBf0je/FpKr2lVSMZLak7m8aaWcS6ur4SOfhSSazFI=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260122111742-a6f99668844f/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs=
github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
@@ -451,8 +451,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/dtls/v3 v3.0.11 h1:zqn8YhoAU7d9whsWLhNiQlbB8QdpJj8XQVSc5ImUons=
github.com/pion/dtls/v3 v3.0.11/go.mod h1:YEmmBYIoBsY3jmG56dsziTv/Lca9y4Om83370CXfqJ8=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
@@ -470,6 +470,8 @@ github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLh
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o=
github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=

File diff suppressed because it is too large Load Diff

View File

@@ -55,7 +55,7 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
config, err = LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
}
@@ -133,35 +133,35 @@ var (
}
)
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
func LoadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
loadedConfig := &nbconfig.Config{}
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
applyCommandLineOverrides(loadedConfig)
ApplyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
err := ApplyEmbeddedIdPConfig(ctx, loadedConfig)
if err != nil {
return nil, err
}
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
if err := ApplyOIDCConfig(ctx, loadedConfig); err != nil {
return nil, err
}
logConfigInfo(loadedConfig)
LogConfigInfo(loadedConfig)
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
if err := EnsureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
return loadedConfig, nil
}
// applyCommandLineOverrides applies command-line flag overrides to the config
func applyCommandLineOverrides(cfg *nbconfig.Config) {
// ApplyCommandLineOverrides applies command-line flag overrides to the config
func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
if mgmtLetsencryptDomain != "" {
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
}
@@ -174,9 +174,9 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
}
}
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// ApplyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
@@ -222,8 +222,8 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
return nil
}
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
// ApplyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" {
return nil
@@ -249,16 +249,16 @@ func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
if err := ApplyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
return err
}
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
ApplyPKCEFlowConfig(ctx, cfg, &oidcConfig)
return nil
}
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
// ApplyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
return nil
}
@@ -285,8 +285,8 @@ func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcCo
return nil
}
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
// ApplyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
if cfg.PKCEAuthorizationFlow == nil {
return
}
@@ -299,8 +299,8 @@ func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
// logConfigInfo logs informational messages about the loaded configuration
func logConfigInfo(cfg *nbconfig.Config) {
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
@@ -309,8 +309,8 @@ func logConfigInfo(cfg *nbconfig.Config) {
}
}
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
// EnsureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func EnsureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}

View File

@@ -30,7 +30,7 @@ func Test_loadMgmtConfig(t *testing.T) {
t.Fatalf("failed to create config: %s", err)
}
cfg, err := loadMgmtConfig(context.Background(), tmpFile)
cfg, err := LoadMgmtConfig(context.Background(), tmpFile)
if err != nil {
t.Fatalf("failed to load management config: %s", err)
}

View File

@@ -247,7 +247,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update})
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
}(peer)
}
@@ -370,7 +373,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update})
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
return nil
}
@@ -778,6 +784,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
},
},
},
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.CloseChannel(ctx, peerID)

View File

@@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) {
func TestSendUpdate(t *testing.T) {
peer := "test-sendupdate"
peersUpdater := NewPeersUpdateManager(nil)
update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 0,
},
},
}}
MessageType: network_map.MessageTypeNetworkMap,
}
_ = peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
@@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) {
peersUpdater.SendUpdate(context.Background(), peer, update1)
}
update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: 10,
},
},
}}
MessageType: network_map.MessageTypeNetworkMap,
}
peersUpdater.SendUpdate(context.Background(), peer, update2)
timeout := time.After(5 * time.Second)

View File

@@ -4,6 +4,19 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
)
// MessageType indicates the type of update message for debouncing strategy
type MessageType int
const (
// MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall)
// These updates can be safely debounced - only the latest state matters
MessageTypeNetworkMap MessageType = iota
// MessageTypeControlConfig represents control/config updates (tokens, peer expiration)
// These updates should not be dropped as they contain time-sensitive information
MessageTypeControlConfig
)
type UpdateMessage struct {
Update *proto.SyncResponse
Update *proto.SyncResponse
MessageType MessageType
}

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"golang.org/x/crypto/acme/autocert"
@@ -19,6 +18,8 @@ import (
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/encryption"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/metrics"
@@ -138,6 +139,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
go metricsWorker.Run(srvCtx)
}
// Run afterInit hooks before starting any servers
// This allows registering additional gRPC services (e.g., Signal) before Serve() is called
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
@@ -178,12 +187,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
}
}
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
@@ -255,7 +258,23 @@ func (s *BaseServer) SetContainer(key string, container any) {
log.Tracef("container with key %s set successfully", key)
}
// SetHandlerFunc allows overriding the default HTTP handler function.
// This is useful for multiplexing additional services on the same port.
func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
s.container["customHandler"] = handler
log.Tracef("custom handler set successfully")
}
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
// Check if a custom handler was set (for multiplexing additional services)
if customHandler, ok := s.GetContainer("customHandler"); ok {
if handler, ok := customHandler.(http.Handler); ok {
log.Tracef("using custom handler")
return handler
}
}
// Use default handler
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {

View File

@@ -300,7 +300,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
@@ -311,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
return err
}
@@ -319,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
if err != nil {
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart)
return err
}
@@ -336,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart)
}
func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
@@ -404,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
// It implements a backpressure mechanism that sends the first update immediately,
// then debounces subsequent rapid updates, ensuring only the latest update is sent
// after a quiet period.
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
// Create a debouncer for this peer connection
debouncer := NewUpdateDebouncer(1000 * time.Millisecond)
defer debouncer.Stop()
for {
select {
// condition when there are some updates
// todo set the updates channel size to 1
case update, open := <-updates:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
@@ -416,20 +425,38 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
if debouncer.ProcessUpdate(update) {
// Send immediately (first update or after quiet period)
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
}
// Timer expired - quiet period reached, send pending updates if any
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
continue
}
log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String())
for _, pendingUpdate := range pendingUpdates {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil {
log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err)
return err
}
}
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return srv.Context().Err()
}
}
@@ -437,16 +464,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error {
key, err := s.secretsManager.GetWGKey()
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed processing update message")
}
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.Send(&proto.EncryptedMessage{
@@ -454,7 +481,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
@@ -486,15 +513,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
return nil
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
}
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) {
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime)
if err != nil {
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
}

View File

@@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeControlConfig,
})
}
func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) {
@@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update})
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeControlConfig,
})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {

View File

@@ -0,0 +1,103 @@
package grpc
import (
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
)
// UpdateDebouncer implements a backpressure mechanism that:
// - Sends the first update immediately
// - Coalesces rapid subsequent network map updates (only latest matters)
// - Queues control/config updates (all must be delivered)
// - Preserves the order of messages (important for control configs between network maps)
// - Ensures pending updates are sent after a quiet period
type UpdateDebouncer struct {
debounceInterval time.Duration
timer *time.Timer
pendingUpdates []*network_map.UpdateMessage // Queue that preserves order
timerC <-chan time.Time
}
// NewUpdateDebouncer creates a new debouncer with the specified interval
func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
return &UpdateDebouncer{
debounceInterval: interval,
}
}
// ProcessUpdate handles an incoming update and returns whether it should be sent immediately
func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
if d.timer == nil {
// No active debounce timer, signal to send immediately
// and start the debounce period
d.startTimer()
return true
}
// Already in debounce period, accumulate this update preserving order
// Check if we should coalesce with the last pending update
if len(d.pendingUpdates) > 0 &&
update.MessageType == network_map.MessageTypeNetworkMap &&
d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap {
// Replace the last network map with this one (coalesce consecutive network maps)
d.pendingUpdates[len(d.pendingUpdates)-1] = update
} else {
// Append to the queue (preserves order for control configs and non-consecutive network maps)
d.pendingUpdates = append(d.pendingUpdates, update)
}
d.resetTimer()
return false
}
// TimerChannel returns the timer channel for select statements
func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
if d.timer == nil {
return nil
}
return d.timerC
}
// GetPendingUpdates returns and clears all pending updates after timer expiration.
// Updates are returned in the order they were received, with consecutive network maps
// already coalesced to only the latest one.
// If there were pending updates, it restarts the timer to continue debouncing.
// If there were no pending updates, it clears the timer (true quiet period).
func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
updates := d.pendingUpdates
d.pendingUpdates = nil
if len(updates) > 0 {
// There were pending updates, so updates are still coming rapidly
// Restart the timer to continue debouncing mode
if d.timer != nil {
d.timer.Reset(d.debounceInterval)
}
} else {
// No pending updates means true quiet period - return to immediate mode
d.timer = nil
d.timerC = nil
}
return updates
}
// Stop stops the debouncer and cleans up resources
func (d *UpdateDebouncer) Stop() {
if d.timer != nil {
d.timer.Stop()
d.timer = nil
d.timerC = nil
}
d.pendingUpdates = nil
}
func (d *UpdateDebouncer) startTimer() {
d.timer = time.NewTimer(d.debounceInterval)
d.timerC = d.timer.C
}
func (d *UpdateDebouncer) resetTimer() {
d.timer.Stop()
d.timer.Reset(d.debounceInterval)
}

View File

@@ -0,0 +1,587 @@
package grpc
import (
"testing"
"time"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/shared/management/proto"
)
func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
shouldSend := debouncer.ProcessUpdate(update)
if !shouldSend {
t.Error("First update should be sent immediately")
}
if debouncer.TimerChannel() == nil {
t.Error("Timer should be started after first update")
}
}
func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update should be sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Rapid subsequent updates should be coalesced
if debouncer.ProcessUpdate(update2) {
t.Error("Second rapid update should not be sent immediately")
}
if debouncer.ProcessUpdate(update3) {
t.Error("Third rapid update should not be sent immediately")
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Send second update within debounce period
debouncer.ProcessUpdate(update2)
// Wait for timer
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update2 {
t.Error("Should get the last update")
}
if pendingUpdates[0] == update1 {
t.Error("Should not get the first update")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
debouncer.ProcessUpdate(update1)
// Wait a bit, but not the full debounce period
time.Sleep(30 * time.Millisecond)
// Send second update - should reset timer
debouncer.ProcessUpdate(update2)
// Wait a bit more
time.Sleep(30 * time.Millisecond)
// Send third update - should reset timer again
debouncer.ProcessUpdate(update3)
// Now wait for the timer (should fire after last update's reset)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != update3 {
t.Error("Should get the last update (update3)")
}
// Timer should be restarted since there was a pending update
if debouncer.TimerChannel() == nil {
t.Error("Timer should be restarted after sending pending update")
}
case <-time.After(150 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update3 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
debouncer.ProcessUpdate(update1)
// Second update coalesced
debouncer.ProcessUpdate(update2)
// Wait for timer to expire
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have pending update")
}
// After sending pending update, timer is restarted, so next update is NOT immediate
if debouncer.ProcessUpdate(update3) {
t.Error("Update after debounced send should not be sent immediately (timer restarted)")
}
// Wait for the restarted timer and verify update3 is pending
select {
case <-debouncer.TimerChannel():
finalUpdates := debouncer.GetPendingUpdates()
if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
t.Error("Should get update3 as pending")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired for restarted timer")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send update to start timer
debouncer.ProcessUpdate(update)
// Stop should clean up
debouncer.Stop()
// Multiple stops should be safe
debouncer.Stop()
}
func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate high-frequency updates
var lastUpdate *network_map.UpdateMessage
sentImmediately := 0
for i := 0; i < 100; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
lastUpdate = update
if debouncer.ProcessUpdate(update) {
sentImmediately++
}
time.Sleep(1 * time.Millisecond) // Very rapid updates
}
// Only first update should be sent immediately
if sentImmediately != 1 {
t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0] != lastUpdate {
t.Error("Should get the very last update")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// Send first update
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
// Wait for timer to expire with no additional updates (true quiet period)
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates")
}
// After true quiet period, timer should be cleared
if debouncer.TimerChannel() != nil {
t.Error("Timer should be cleared after quiet period")
}
case <-time.After(100 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
updates := make([]*network_map.UpdateMessage, 5)
for i := range updates {
updates[i] = &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
}
// First update sent immediately
debouncer.ProcessUpdate(updates[0])
// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
debouncer.ProcessUpdate(updates[1])
debouncer.ProcessUpdate(updates[2])
debouncer.ProcessUpdate(updates[3])
debouncer.ProcessUpdate(updates[4])
// Wait for debounce
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 1 {
t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
}
if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
}
func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
debouncer := NewUpdateDebouncer(30 * time.Millisecond)
defer debouncer.Stop()
update1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
update2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{},
MessageType: network_map.MessageTypeNetworkMap,
}
// First update sent immediately
if !debouncer.ProcessUpdate(update1) {
t.Error("First update should be sent immediately")
}
// Wait for timer without sending any more updates (true quiet period)
<-debouncer.TimerChannel()
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) != 0 {
t.Error("Should have no pending updates during quiet period")
}
// After true quiet period, next update should be sent immediately
if !debouncer.ProcessUpdate(update2) {
t.Error("Update after true quiet period should be sent immediately")
}
}
func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate continuous high-frequency updates
for i := 0; i < 10; i++ {
update := &network_map.UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{
Serial: uint64(i),
},
},
MessageType: network_map.MessageTypeNetworkMap,
}
if i == 0 {
// First one sent immediately
if !debouncer.ProcessUpdate(update) {
t.Error("First update should be sent immediately")
}
} else {
// All others should be coalesced (not sent immediately)
if debouncer.ProcessUpdate(update) {
t.Errorf("Update %d should not be sent immediately", i)
}
}
// Wait a bit but send next update before debounce expires
time.Sleep(20 * time.Millisecond)
}
// Now wait for final debounce
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
if len(pendingUpdates) == 0 {
t.Fatal("Should have the last update pending")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
tokenUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate)
// Send multiple control config updates - they should all be queued
debouncer.ProcessUpdate(tokenUpdate1)
debouncer.ProcessUpdate(tokenUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get both control config updates
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates))
}
// Control configs should come first
if pendingUpdates[0] != tokenUpdate1 {
t.Error("First pending update should be tokenUpdate1")
}
if pendingUpdates[1] != tokenUpdate2 {
t.Error("Second pending update should be tokenUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
netmapUpdate1 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
MessageType: network_map.MessageTypeNetworkMap,
}
netmapUpdate2 := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
MessageType: network_map.MessageTypeNetworkMap,
}
tokenUpdate := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
// First update sent immediately
debouncer.ProcessUpdate(netmapUpdate1)
// Send token update and network map update
debouncer.ProcessUpdate(tokenUpdate)
debouncer.ProcessUpdate(netmapUpdate2)
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get 2 updates in order: token, then network map
if len(pendingUpdates) != 2 {
t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates))
}
// Token update should come first (preserves order)
if pendingUpdates[0] != tokenUpdate {
t.Error("First pending update should be tokenUpdate")
}
// Network map update should come second
if pendingUpdates[1] != netmapUpdate2 {
t.Error("Second pending update should be netmapUpdate2")
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}
func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
debouncer := NewUpdateDebouncer(50 * time.Millisecond)
defer debouncer.Stop()
// Simulate: 50 network maps -> 1 control config -> 50 network maps
// Expected result: 3 messages (netmap, controlConfig, netmap)
// Send first network map immediately
firstNetmap := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
MessageType: network_map.MessageTypeNetworkMap,
}
if !debouncer.ProcessUpdate(firstNetmap) {
t.Error("First update should be sent immediately")
}
// Send 49 more network maps (will be coalesced to last one)
var lastNetmapBatch1 *network_map.UpdateMessage
for i := 1; i < 50; i++ {
lastNetmapBatch1 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch1)
}
// Send 1 control config
controlConfig := &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
MessageType: network_map.MessageTypeControlConfig,
}
debouncer.ProcessUpdate(controlConfig)
// Send 50 more network maps (will be coalesced to last one)
var lastNetmapBatch2 *network_map.UpdateMessage
for i := 50; i < 100; i++ {
lastNetmapBatch2 = &network_map.UpdateMessage{
Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
MessageType: network_map.MessageTypeNetworkMap,
}
debouncer.ProcessUpdate(lastNetmapBatch2)
}
// Wait for debounce period
select {
case <-debouncer.TimerChannel():
pendingUpdates := debouncer.GetPendingUpdates()
// Should get exactly 3 updates: netmap, controlConfig, netmap
if len(pendingUpdates) != 3 {
t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates))
}
// First should be the last netmap from batch 1
if pendingUpdates[0] != lastNetmapBatch1 {
t.Error("First pending update should be last netmap from batch 1")
}
if pendingUpdates[0].Update.NetworkMap.Serial != 49 {
t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
}
// Second should be the control config
if pendingUpdates[1] != controlConfig {
t.Error("Second pending update should be control config")
}
// Third should be the last netmap from batch 2
if pendingUpdates[2] != lastNetmapBatch2 {
t.Error("Third pending update should be last netmap from batch 2")
}
if pendingUpdates[2].Update.NetworkMap.Serial != 99 {
t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial)
}
case <-time.After(200 * time.Millisecond):
t.Error("Timer should have fired")
}
}

View File

@@ -1670,13 +1670,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1684,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return peer, netMap, postureChecks, dnsfwdPort, nil
}
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error {
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
return nil
}
if peer.Status.LastSeen.After(streamStartTime) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
return nil
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}

View File

@@ -58,7 +58,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error
@@ -114,8 +114,8 @@ type Manager interface {
UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)

View File

@@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1961,6 +1961,82 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
require.NoError(t, err, "unable to add peer")
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err, "unable to get peer")
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := time.Now().UTC()
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.False(t, peer.Status.Connected, "peer should be disconnected")
})
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected,
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
})
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should still be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
})
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -1983,7 +2059,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -3176,7 +3252,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1})
_, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC())
assert.NoError(b, err)
}

View File

@@ -9,10 +9,11 @@ import (
"time"
"github.com/gorilla/mux"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
@@ -137,7 +138,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
}
accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router, networkMapController)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddPublicInvitesEndpoints(accountManager, router)

View File

@@ -17,6 +17,9 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -26,11 +29,12 @@ import (
// Handler is a handler that returns peers of the account
type Handler struct {
accountManager account.Manager
permissionsManager permissions.Manager
networkMapController network_map.Controller
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) {
peersHandler := NewHandler(accountManager, networkMapController)
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
@@ -42,10 +46,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap
}
// NewHandler creates a new peers Handler
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler {
func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler {
return &Handler{
accountManager: accountManager,
networkMapController: networkMapController,
permissionsManager: permissionsManager,
}
}
@@ -359,21 +364,30 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
// If the user is regular user and does not own the peer
// with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser && !userAuth.IsChild {
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
if err != nil {
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if !allowed && !userAuth.IsChild {
if account.Settings.RegularUsersViewBlocked {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return
}
peer, ok := account.Peers[peerID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)

View File

@@ -13,13 +13,17 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"go.uber.org/mock/gomock"
ugomock "go.uber.org/mock/gomock"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -102,7 +106,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
},
}
ctrl := gomock.NewController(t)
ctrl := ugomock.NewController(t)
networkMapController := network_map.NewMockController(ctrl)
networkMapController.EXPECT().
@@ -110,6 +114,20 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
Return("domain").
AnyTimes()
ctrl2 := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl2)
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
permissionsManager.EXPECT().
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
DoAndReturn(func(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) {
user, ok := account.Users[userID]
if !ok {
return false, fmt.Errorf("user not found")
}
return user.HasAdminPower() || user.IsServiceUser, nil
}).
AnyTimes()
return &Handler{
accountManager: &mock_server.MockAccountManager{
UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) {
@@ -199,6 +217,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
},
},
networkMapController: networkMapController,
permissionsManager: permissionsManager,
}
}
@@ -376,12 +395,11 @@ func TestGetAccessiblePeers(t *testing.T) {
UserID: regularUser,
}
p := initTestMetaData(t, peer1, peer2, peer3)
tt := []struct {
name string
peerID string
callerUserID string
viewBlocked bool
expectedStatus int
expectedPeers []string
}{
@@ -420,10 +438,56 @@ func TestGetAccessiblePeers(t *testing.T) {
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
{
name: "regular user gets empty for owned peer list when view blocked",
peerID: "peer1",
callerUserID: regularUser,
viewBlocked: true,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "regular user gets empty list for unowned peer when view blocked",
peerID: "peer2",
callerUserID: regularUser,
viewBlocked: true,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "admin user still sees accessible peers when view blocked",
peerID: "peer2",
callerUserID: adminUser,
viewBlocked: true,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer3"},
},
{
name: "service user still sees accessible peers when view blocked",
peerID: "peer3",
callerUserID: serviceUser,
viewBlocked: true,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
p := initTestMetaData(t, peer1, peer2, peer3)
if tc.viewBlocked {
mockAM := p.accountManager.(*mock_server.MockAccountManager)
originalGetAccountByIDFunc := mockAM.GetAccountByIDFunc
mockAM.GetAccountByIDFunc = func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
account, err := originalGetAccountByIDFunc(ctx, accountID, userID)
if err != nil {
return nil, err
}
account.Settings.RegularUsersViewBlocked = true
return account, nil
}
}
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"github.com/netbirdio/management-integrations/integrations"
serverauth "github.com/netbirdio/netbird/management/server/auth"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
@@ -130,8 +131,10 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = ok
if integrations.IsValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
}
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
@@ -207,8 +210,10 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = ok
if integrations.IsValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
}
return nbcontext.SetUserAuthInRequest(r, userAuth), nil

View File

@@ -627,15 +627,14 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
},
},
{
name: "Valid PAT Token accesses child",
name: "PAT Token with account param ignored in public version",
path: "/test?account=xyz",
authHeader: "Token " + PAT,
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsChild: true,
IsPAT: true,
},
},
@@ -652,15 +651,14 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
},
{
name: "Valid JWT Token with child",
name: "JWT Token with account param ignored in public version",
path: "/test?account=xyz",
authHeader: "Bearer " + JWT,
expectedUserAuth: &nbauth.UserAuth{
AccountId: "xyz",
AccountId: accountID,
UserId: userID,
Domain: testAccount.Domain,
DomainCategory: testAccount.DomainCategory,
IsChild: true,
},
},
}

View File

@@ -610,6 +610,7 @@ func TestSync10PeersGetUpdates(t *testing.T) {
initialPeers := 10
additionalPeers := 10
expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself
var peers []wgtypes.Key
for i := 0; i < initialPeers; i++ {
@@ -618,8 +619,19 @@ func TestSync10PeersGetUpdates(t *testing.T) {
peers = append(peers, key)
}
// Track the maximum peer count each peer has seen
type peerState struct {
mu sync.Mutex
maxPeerCount int
done bool
}
peerStates := make(map[string]*peerState)
for _, pk := range peers {
peerStates[pk.PublicKey().String()] = &peerState{}
}
var wg sync.WaitGroup
wg.Add(initialPeers + initialPeers*additionalPeers)
wg.Add(initialPeers) // One completion per initial peer
var syncClients []mgmtProto.ManagementService_SyncClient
for _, pk := range peers {
@@ -643,6 +655,9 @@ func TestSync10PeersGetUpdates(t *testing.T) {
syncClients = append(syncClients, s)
go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) {
pubKey := pk.PublicKey().String()
state := peerStates[pubKey]
for {
encMsg := &mgmtProto.EncryptedMessage{}
err := syncStream.RecvMsg(encMsg)
@@ -651,19 +666,28 @@ func TestSync10PeersGetUpdates(t *testing.T) {
}
decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk)
if decErr != nil {
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr)
t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr)
return
}
resp := &mgmtProto.SyncResponse{}
umErr := pb.Unmarshal(decryptedBytes, resp)
if umErr != nil {
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr)
t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr)
return
}
// We only count if there's a new peer update
if len(resp.GetRemotePeers()) > 0 {
// Track the maximum peer count seen (due to debouncing, updates are coalesced)
peerCount := len(resp.GetRemotePeers())
state.mu.Lock()
if peerCount > state.maxPeerCount {
state.maxPeerCount = peerCount
}
// Signal completion when this peer has seen all expected peers
if !state.done && state.maxPeerCount >= expectedPeerCount {
state.done = true
wg.Done()
}
state.mu.Unlock()
}
}(pk, s)
}
@@ -677,7 +701,30 @@ func TestSync10PeersGetUpdates(t *testing.T) {
time.Sleep(time.Duration(n) * time.Millisecond)
}
wg.Wait()
// Wait for debouncer to flush final updates (debounce interval is 1000ms)
time.Sleep(1500 * time.Millisecond)
// Wait with timeout
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
// Success - all peers received expected peer count
case <-time.After(5 * time.Second):
// Timeout - report which peers didn't receive all updates
t.Error("Timeout waiting for all peers to receive updates")
for pubKey, state := range peerStates {
state.mu.Lock()
if state.maxPeerCount < expectedPeerCount {
t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount)
}
state.mu.Unlock()
}
}
for _, sc := range syncClients {
err := sc.CloseSend()

View File

@@ -37,8 +37,8 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
@@ -214,16 +214,15 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use
return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented")
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if am.SyncAndMarkPeerFunc != nil {
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP)
return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime)
}
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error {
// TODO implement me
panic("implement me")
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
return nil
}
func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {
@@ -323,9 +322,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}

View File

@@ -103,11 +103,13 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
// syncTime is used as the LastSeen timestamp and for stale request detection
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
var err error
var skipped bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
@@ -115,9 +117,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return err
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
if connected && !syncTime.After(peer.Status.LastSeen) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
skipped = true
return nil
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
return err
})
if skipped {
return nil
}
if err != nil {
return err
}
@@ -147,10 +159,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC()
newStatus.LastSeen = syncTime
newStatus.Connected = connected
// whenever peer got connected that means that it logged in successfully
if newStatus.Connected {

View File

@@ -2643,7 +2643,7 @@ func getGormConfig() *gorm.Config {
// newPostgresStore initializes a new Postgres store.
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
dsn, ok := os.LookupEnv(postgresDsnEnv)
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
if !ok {
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
@@ -2652,7 +2652,7 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMig
// newMysqlStore initializes a new MySQL store.
func newMysqlStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
dsn, ok := os.LookupEnv(mysqlDsnEnv)
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
if !ok {
return nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}

View File

@@ -243,10 +243,20 @@ type Store interface {
}
const (
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN"
postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NB_STORE_ENGINE_MYSQL_DSN"
mysqlDsnEnvLegacy = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
// lookupDSNEnv checks the NB_ env var first, then falls back to the legacy NETBIRD_ env var.
func lookupDSNEnv(nbKey, legacyKey string) (string, bool) {
if v, ok := os.LookupEnv(nbKey); ok {
return v, true
}
return os.LookupEnv(legacyKey)
}
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
func getStoreEngineFromEnv() types.Engine {
@@ -531,7 +541,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
}
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
dsn, ok := os.LookupEnv(postgresDsnEnv)
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
if !ok || dsn == "" {
var err error
_, dsn, err = testutil.CreatePostgresTestContainer()
@@ -569,7 +579,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
}
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
dsn, ok := os.LookupEnv(mysqlDsnEnv)
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
if !ok || dsn == "" {
var err error
_, dsn, err = testutil.CreateMysqlTestContainer()

View File

@@ -122,6 +122,7 @@ type defaultAppMetrics struct {
Meter metric2.Meter
listener net.Listener
ctx context.Context
externallyManaged bool
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
@@ -171,6 +172,9 @@ func (appMetrics *defaultAppMetrics) Close() error {
// Expose metrics on a given port and endpoint. If endpoint is empty a defaultEndpoint one will be used.
// Exposes metrics in the Prometheus format https://prometheus.io/
func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpoint string) error {
if appMetrics.externallyManaged {
return nil
}
if endpoint == "" {
endpoint = defaultEndpoint
}
@@ -252,3 +256,49 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
accountManagerMetrics: accountManagerMetrics,
}, nil
}
// NewAppMetricsWithMeter creates AppMetrics using an externally provided meter.
// The caller is responsible for exposing metrics via HTTP. Expose() and Close() are no-ops.
func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetrics, error) {
idpMetrics, err := NewIDPMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize IDP metrics: %w", err)
}
middleware, err := NewMetricsMiddleware(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize HTTP middleware metrics: %w", err)
}
grpcMetrics, err := NewGRPCMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize gRPC metrics: %w", err)
}
storeMetrics, err := NewStoreMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize store metrics: %w", err)
}
updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize update channel metrics: %w", err)
}
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
externallyManaged: true,
idpMetrics: idpMetrics,
httpMiddleware: middleware,
grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
}, nil
}

View File

@@ -21,8 +21,8 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
)

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"crypto/tls"
"net"
"net/url"
"sync"
@@ -134,3 +135,10 @@ func (r *Server) ListenerProtocols() []protocol.Protocol {
func (r *Server) InstanceURL() url.URL {
return r.relay.InstanceURL()
}
// RelayAccept returns the relay's Accept function for handling incoming connections.
// This allows external HTTP handlers to route connections to the relay without
// starting the relay's own listeners.
func (r *Server) RelayAccept() func(conn net.Conn) {
return r.relay.Accept
}

View File

@@ -0,0 +1,82 @@
package rest
import (
"context"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// BillingAPI APIs for billing and invoices
type BillingAPI struct {
c *Client
}
// GetUsage retrieves current usage statistics for the account
// See more: https://docs.netbird.io/api/resources/billing#get-current-usage
func (a *BillingAPI) GetUsage(ctx context.Context) (*api.UsageStats, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/usage", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UsageStats](resp)
return &ret, err
}
// GetSubscription retrieves the current subscription details
// See more: https://docs.netbird.io/api/resources/billing#get-current-subscription
func (a *BillingAPI) GetSubscription(ctx context.Context) (*api.Subscription, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/subscription", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Subscription](resp)
return &ret, err
}
// GetInvoices retrieves the account's paid invoices
// See more: https://docs.netbird.io/api/resources/billing#list-all-invoices
func (a *BillingAPI) GetInvoices(ctx context.Context) ([]api.InvoiceResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.InvoiceResponse](resp)
return ret, err
}
// GetInvoicePDF retrieves the invoice PDF URL
// See more: https://docs.netbird.io/api/resources/billing#get-invoice-pdf
func (a *BillingAPI) GetInvoicePDF(ctx context.Context, invoiceID string) (*api.InvoicePDFResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices/"+invoiceID+"/pdf", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.InvoicePDFResponse](resp)
return &ret, err
}
// GetInvoiceCSV retrieves the invoice CSV content
// See more: https://docs.netbird.io/api/resources/billing#get-invoice-csv
func (a *BillingAPI) GetInvoiceCSV(ctx context.Context, invoiceID string) (string, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices/"+invoiceID+"/csv", nil, nil)
if err != nil {
return "", err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[string](resp)
return ret, err
}

View File

@@ -0,0 +1,194 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testUsageStats = api.UsageStats{
ActiveUsers: 15,
TotalUsers: 20,
ActivePeers: 10,
TotalPeers: 25,
}
testSubscription = api.Subscription{
Active: true,
PlanTier: "basic",
PriceId: "price_1HhxOp",
Currency: "USD",
Price: 1000,
Provider: "stripe",
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
testInvoice = api.InvoiceResponse{
Id: "inv_123",
PeriodStart: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
PeriodEnd: time.Date(2024, 2, 1, 0, 0, 0, 0, time.UTC),
Type: "invoice",
}
testInvoicePDF = api.InvoicePDFResponse{
Url: "https://example.com/invoice.pdf",
}
)
func TestBilling_GetUsage_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/usage", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testUsageStats)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetUsage(context.Background())
require.NoError(t, err)
assert.Equal(t, testUsageStats, *ret)
})
}
func TestBilling_GetUsage_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/usage", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetUsage(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestBilling_GetSubscription_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/subscription", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testSubscription)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetSubscription(context.Background())
require.NoError(t, err)
assert.Equal(t, testSubscription, *ret)
})
}
func TestBilling_GetSubscription_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/subscription", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetSubscription(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestBilling_GetInvoices_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.InvoiceResponse{testInvoice})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoices(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testInvoice, ret[0])
})
}
func TestBilling_GetInvoices_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoices(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestBilling_GetInvoicePDF_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/pdf", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testInvoicePDF)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoicePDF(context.Background(), "inv_123")
require.NoError(t, err)
assert.Equal(t, testInvoicePDF, *ret)
})
}
func TestBilling_GetInvoicePDF_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/pdf", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoicePDF(context.Background(), "inv_123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestBilling_GetInvoiceCSV_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/csv", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal("col1,col2\nval1,val2")
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoiceCSV(context.Background(), "inv_123")
require.NoError(t, err)
assert.Equal(t, "col1,col2\nval1,val2", ret)
})
}
func TestBilling_GetInvoiceCSV_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/csv", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoiceCSV(context.Background(), "inv_123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}

View File

@@ -73,6 +73,38 @@ type Client struct {
// Events NetBird Events APIs
// see more: https://docs.netbird.io/api/resources/events
Events *EventsAPI
// Billing NetBird Billing APIs for subscriptions, plans, and invoices
// see more: https://docs.netbird.io/api/resources/billing
Billing *BillingAPI
// MSP NetBird MSP tenant management APIs
// see more: https://docs.netbird.io/api/resources/msp
MSP *MSPAPI
// EDR NetBird EDR integration APIs (Intune, SentinelOne, Falcon, Huntress)
// see more: https://docs.netbird.io/api/resources/edr
EDR *EDRAPI
// SCIM NetBird SCIM IDP integration APIs
// see more: https://docs.netbird.io/api/resources/scim
SCIM *SCIMAPI
// EventStreaming NetBird Event Streaming integration APIs
// see more: https://docs.netbird.io/api/resources/event-streaming
EventStreaming *EventStreamingAPI
// IdentityProviders NetBird Identity Providers APIs
// see more: https://docs.netbird.io/api/resources/identity-providers
IdentityProviders *IdentityProvidersAPI
// Ingress NetBird Ingress Peers APIs
// see more: https://docs.netbird.io/api/resources/ingress-ports
Ingress *IngressAPI
// Instance NetBird Instance API
// see more: https://docs.netbird.io/api/resources/instance
Instance *InstanceAPI
}
// New initialize new Client instance using PAT token
@@ -120,6 +152,14 @@ func (c *Client) initialize() {
c.DNSZones = &DNSZonesAPI{c}
c.GeoLocation = &GeoLocationAPI{c}
c.Events = &EventsAPI{c}
c.Billing = &BillingAPI{c}
c.MSP = &MSPAPI{c}
c.EDR = &EDRAPI{c}
c.SCIM = &SCIMAPI{c}
c.EventStreaming = &EventStreamingAPI{c}
c.IdentityProviders = &IdentityProvidersAPI{c}
c.Ingress = &IngressAPI{c}
c.Instance = &InstanceAPI{c}
}
// NewRequest creates and executes new management API request

View File

@@ -0,0 +1,307 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// EDRAPI APIs for EDR integrations (Intune, SentinelOne, Falcon, Huntress)
type EDRAPI struct {
c *Client
}
// GetIntuneIntegration retrieves the EDR Intune integration
// See more: https://docs.netbird.io/api/resources/edr#get-intune-integration
func (a *EDRAPI) GetIntuneIntegration(ctx context.Context) (*api.EDRIntuneResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/intune", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRIntuneResponse](resp)
return &ret, err
}
// CreateIntuneIntegration creates a new EDR Intune integration
// See more: https://docs.netbird.io/api/resources/edr#create-intune-integration
func (a *EDRAPI) CreateIntuneIntegration(ctx context.Context, request api.EDRIntuneRequest) (*api.EDRIntuneResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/intune", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRIntuneResponse](resp)
return &ret, err
}
// UpdateIntuneIntegration updates an existing EDR Intune integration
// See more: https://docs.netbird.io/api/resources/edr#update-intune-integration
func (a *EDRAPI) UpdateIntuneIntegration(ctx context.Context, request api.EDRIntuneRequest) (*api.EDRIntuneResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/intune", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRIntuneResponse](resp)
return &ret, err
}
// DeleteIntuneIntegration deletes the EDR Intune integration
// See more: https://docs.netbird.io/api/resources/edr#delete-intune-integration
func (a *EDRAPI) DeleteIntuneIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/intune", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// GetSentinelOneIntegration retrieves the EDR SentinelOne integration
// See more: https://docs.netbird.io/api/resources/edr#get-sentinelone-integration
func (a *EDRAPI) GetSentinelOneIntegration(ctx context.Context) (*api.EDRSentinelOneResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/sentinelone", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
return &ret, err
}
// CreateSentinelOneIntegration creates a new EDR SentinelOne integration
// See more: https://docs.netbird.io/api/resources/edr#create-sentinelone-integration
func (a *EDRAPI) CreateSentinelOneIntegration(ctx context.Context, request api.EDRSentinelOneRequest) (*api.EDRSentinelOneResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/sentinelone", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
return &ret, err
}
// UpdateSentinelOneIntegration updates an existing EDR SentinelOne integration
// See more: https://docs.netbird.io/api/resources/edr#update-sentinelone-integration
func (a *EDRAPI) UpdateSentinelOneIntegration(ctx context.Context, request api.EDRSentinelOneRequest) (*api.EDRSentinelOneResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/sentinelone", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
return &ret, err
}
// DeleteSentinelOneIntegration deletes the EDR SentinelOne integration
// See more: https://docs.netbird.io/api/resources/edr#delete-sentinelone-integration
func (a *EDRAPI) DeleteSentinelOneIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/sentinelone", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// GetFalconIntegration retrieves the EDR Falcon integration
// See more: https://docs.netbird.io/api/resources/edr#get-falcon-integration
func (a *EDRAPI) GetFalconIntegration(ctx context.Context) (*api.EDRFalconResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/falcon", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFalconResponse](resp)
return &ret, err
}
// CreateFalconIntegration creates a new EDR Falcon integration
// See more: https://docs.netbird.io/api/resources/edr#create-falcon-integration
func (a *EDRAPI) CreateFalconIntegration(ctx context.Context, request api.EDRFalconRequest) (*api.EDRFalconResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/falcon", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFalconResponse](resp)
return &ret, err
}
// UpdateFalconIntegration updates an existing EDR Falcon integration
// See more: https://docs.netbird.io/api/resources/edr#update-falcon-integration
func (a *EDRAPI) UpdateFalconIntegration(ctx context.Context, request api.EDRFalconRequest) (*api.EDRFalconResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/falcon", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFalconResponse](resp)
return &ret, err
}
// DeleteFalconIntegration deletes the EDR Falcon integration
// See more: https://docs.netbird.io/api/resources/edr#delete-falcon-integration
func (a *EDRAPI) DeleteFalconIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/falcon", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// GetHuntressIntegration retrieves the EDR Huntress integration
// See more: https://docs.netbird.io/api/resources/edr#get-huntress-integration
func (a *EDRAPI) GetHuntressIntegration(ctx context.Context) (*api.EDRHuntressResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/huntress", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRHuntressResponse](resp)
return &ret, err
}
// CreateHuntressIntegration creates a new EDR Huntress integration
// See more: https://docs.netbird.io/api/resources/edr#create-huntress-integration
func (a *EDRAPI) CreateHuntressIntegration(ctx context.Context, request api.EDRHuntressRequest) (*api.EDRHuntressResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/huntress", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRHuntressResponse](resp)
return &ret, err
}
// UpdateHuntressIntegration updates an existing EDR Huntress integration
// See more: https://docs.netbird.io/api/resources/edr#update-huntress-integration
func (a *EDRAPI) UpdateHuntressIntegration(ctx context.Context, request api.EDRHuntressRequest) (*api.EDRHuntressResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/huntress", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRHuntressResponse](resp)
return &ret, err
}
// DeleteHuntressIntegration deletes the EDR Huntress integration
// See more: https://docs.netbird.io/api/resources/edr#delete-huntress-integration
func (a *EDRAPI) DeleteHuntressIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/huntress", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// BypassPeerCompliance bypasses compliance for a non-compliant peer
// See more: https://docs.netbird.io/api/resources/edr#bypass-peer-compliance
func (a *EDRAPI) BypassPeerCompliance(ctx context.Context, peerID string) (*api.BypassResponse, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+peerID+"/edr/bypass", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.BypassResponse](resp)
return &ret, err
}
// RevokePeerBypass revokes the compliance bypass for a peer
// See more: https://docs.netbird.io/api/resources/edr#revoke-peer-bypass
func (a *EDRAPI) RevokePeerBypass(ctx context.Context, peerID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID+"/edr/bypass", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// ListBypassedPeers returns all peers that have compliance bypassed
// See more: https://docs.netbird.io/api/resources/edr#list-all-bypassed-peers
func (a *EDRAPI) ListBypassedPeers(ctx context.Context) ([]api.BypassResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/edr/bypassed", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.BypassResponse](resp)
return ret, err
}

View File

@@ -0,0 +1,422 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testIntuneResponse = api.EDRIntuneResponse{
AccountId: "acc-1",
ClientId: "client-1",
TenantId: "tenant-1",
Enabled: true,
Id: 1,
Groups: []api.Group{},
LastSyncedInterval: 24,
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
CreatedBy: "user-1",
}
testSentinelOneResponse = api.EDRSentinelOneResponse{
AccountId: "acc-1",
ApiUrl: "https://sentinelone.example.com",
Enabled: true,
Id: 2,
Groups: []api.Group{},
LastSyncedInterval: 24,
MatchAttributes: api.SentinelOneMatchAttributes{},
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
CreatedBy: "user-1",
}
testFalconResponse = api.EDRFalconResponse{
AccountId: "acc-1",
CloudId: "us-1",
Enabled: true,
Id: 3,
Groups: []api.Group{},
ZtaScoreThreshold: 50,
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
CreatedBy: "user-1",
}
testHuntressResponse = api.EDRHuntressResponse{
AccountId: "acc-1",
Enabled: true,
Id: 4,
Groups: []api.Group{},
LastSyncedInterval: 24,
MatchAttributes: api.HuntressMatchAttributes{},
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
CreatedBy: "user-1",
}
testBypassResponse = api.BypassResponse{
PeerId: "peer-1",
}
)
// Intune tests
func TestEDR_GetIntuneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testIntuneResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetIntuneIntegration(context.Background())
require.NoError(t, err)
assert.Equal(t, testIntuneResponse, *ret)
})
}
func TestEDR_GetIntuneIntegration_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetIntuneIntegration(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestEDR_CreateIntuneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.EDRIntuneRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "client-1", req.ClientId)
retBytes, _ := json.Marshal(testIntuneResponse)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateIntuneIntegration(context.Background(), api.EDRIntuneRequest{
ClientId: "client-1",
Secret: "secret",
TenantId: "tenant-1",
Groups: []string{"group-1"},
LastSyncedInterval: 24,
})
require.NoError(t, err)
assert.Equal(t, testIntuneResponse, *ret)
})
}
func TestEDR_CreateIntuneIntegration_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateIntuneIntegration(context.Background(), api.EDRIntuneRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestEDR_UpdateIntuneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
retBytes, _ := json.Marshal(testIntuneResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.UpdateIntuneIntegration(context.Background(), api.EDRIntuneRequest{
ClientId: "client-1",
Secret: "new-secret",
TenantId: "tenant-1",
Groups: []string{"group-1"},
})
require.NoError(t, err)
assert.Equal(t, testIntuneResponse, *ret)
})
}
func TestEDR_DeleteIntuneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.DeleteIntuneIntegration(context.Background())
require.NoError(t, err)
})
}
func TestEDR_DeleteIntuneIntegration_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.EDR.DeleteIntuneIntegration(context.Background())
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
// SentinelOne tests
func TestEDR_GetSentinelOneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testSentinelOneResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetSentinelOneIntegration(context.Background())
require.NoError(t, err)
assert.Equal(t, testSentinelOneResponse, *ret)
})
}
func TestEDR_CreateSentinelOneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testSentinelOneResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateSentinelOneIntegration(context.Background(), api.EDRSentinelOneRequest{
ApiToken: "token",
ApiUrl: "https://sentinelone.example.com",
Groups: []string{"group-1"},
LastSyncedInterval: 24,
MatchAttributes: api.SentinelOneMatchAttributes{},
})
require.NoError(t, err)
assert.Equal(t, testSentinelOneResponse, *ret)
})
}
func TestEDR_DeleteSentinelOneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.DeleteSentinelOneIntegration(context.Background())
require.NoError(t, err)
})
}
// Falcon tests
func TestEDR_GetFalconIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testFalconResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetFalconIntegration(context.Background())
require.NoError(t, err)
assert.Equal(t, testFalconResponse, *ret)
})
}
func TestEDR_CreateFalconIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testFalconResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateFalconIntegration(context.Background(), api.EDRFalconRequest{
ClientId: "client-1",
Secret: "secret",
CloudId: "us-1",
Groups: []string{"group-1"},
ZtaScoreThreshold: 50,
})
require.NoError(t, err)
assert.Equal(t, testFalconResponse, *ret)
})
}
func TestEDR_DeleteFalconIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.DeleteFalconIntegration(context.Background())
require.NoError(t, err)
})
}
// Huntress tests
func TestEDR_GetHuntressIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testHuntressResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetHuntressIntegration(context.Background())
require.NoError(t, err)
assert.Equal(t, testHuntressResponse, *ret)
})
}
func TestEDR_CreateHuntressIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testHuntressResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateHuntressIntegration(context.Background(), api.EDRHuntressRequest{
ApiKey: "key",
ApiSecret: "secret",
Groups: []string{"group-1"},
LastSyncedInterval: 24,
MatchAttributes: api.HuntressMatchAttributes{},
})
require.NoError(t, err)
assert.Equal(t, testHuntressResponse, *ret)
})
}
func TestEDR_DeleteHuntressIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.DeleteHuntressIntegration(context.Background())
require.NoError(t, err)
})
}
// Peer bypass tests
func TestEDR_BypassPeerCompliance_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testBypassResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.BypassPeerCompliance(context.Background(), "peer-1")
require.NoError(t, err)
assert.Equal(t, testBypassResponse, *ret)
})
}
func TestEDR_BypassPeerCompliance_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Bad request", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.BypassPeerCompliance(context.Background(), "peer-1")
assert.Error(t, err)
assert.Equal(t, "Bad request", err.Error())
assert.Nil(t, ret)
})
}
func TestEDR_RevokePeerBypass_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.RevokePeerBypass(context.Background(), "peer-1")
require.NoError(t, err)
})
}
func TestEDR_RevokePeerBypass_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.EDR.RevokePeerBypass(context.Background(), "peer-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestEDR_ListBypassedPeers_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/edr/bypassed", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.BypassResponse{testBypassResponse})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.ListBypassedPeers(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testBypassResponse, ret[0])
})
}
func TestEDR_ListBypassedPeers_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/edr/bypassed", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.ListBypassedPeers(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"strconv"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// EventStreamingAPI APIs for event streaming integrations
type EventStreamingAPI struct {
c *Client
}
// List retrieves all event streaming integrations
// See more: https://docs.netbird.io/api/resources/event-streaming#list-all-event-streaming-integrations
func (a *EventStreamingAPI) List(ctx context.Context) ([]api.IntegrationResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/event-streaming", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IntegrationResponse](resp)
return ret, err
}
// Get retrieves a specific event streaming integration by ID
// See more: https://docs.netbird.io/api/resources/event-streaming#retrieve-an-event-streaming-integration
func (a *EventStreamingAPI) Get(ctx context.Context, integrationID int) (*api.IntegrationResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/event-streaming/"+strconv.Itoa(integrationID), nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IntegrationResponse](resp)
return &ret, err
}
// Create creates a new event streaming integration
// See more: https://docs.netbird.io/api/resources/event-streaming#create-an-event-streaming-integration
func (a *EventStreamingAPI) Create(ctx context.Context, request api.CreateIntegrationRequest) (*api.IntegrationResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/event-streaming", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IntegrationResponse](resp)
return &ret, err
}
// Update updates an existing event streaming integration
// See more: https://docs.netbird.io/api/resources/event-streaming#update-an-event-streaming-integration
func (a *EventStreamingAPI) Update(ctx context.Context, integrationID int, request api.CreateIntegrationRequest) (*api.IntegrationResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/event-streaming/"+strconv.Itoa(integrationID), bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IntegrationResponse](resp)
return &ret, err
}
// Delete deletes an event streaming integration
// See more: https://docs.netbird.io/api/resources/event-streaming#delete-an-event-streaming-integration
func (a *EventStreamingAPI) Delete(ctx context.Context, integrationID int) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/event-streaming/"+strconv.Itoa(integrationID), nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,194 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testIntegrationResponse = api.IntegrationResponse{
Id: ptr[int64](1),
AccountId: ptr("acc-1"),
Platform: (*api.IntegrationResponsePlatform)(ptr("datadog")),
Enabled: ptr(true),
Config: &map[string]string{"api_key": "****"},
CreatedAt: ptr(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)),
UpdatedAt: ptr(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)),
}
)
func TestEventStreaming_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.IntegrationResponse{testIntegrationResponse})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testIntegrationResponse, ret[0])
})
}
func TestEventStreaming_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestEventStreaming_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testIntegrationResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Get(context.Background(), 1)
require.NoError(t, err)
assert.Equal(t, testIntegrationResponse, *ret)
})
}
func TestEventStreaming_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Get(context.Background(), 1)
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestEventStreaming_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, api.CreateIntegrationRequestPlatformDatadog, req.Platform)
assert.Equal(t, true, req.Enabled)
retBytes, _ := json.Marshal(testIntegrationResponse)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Create(context.Background(), api.CreateIntegrationRequest{
Platform: api.CreateIntegrationRequestPlatformDatadog,
Enabled: true,
Config: map[string]string{"api_key": "test-key"},
})
require.NoError(t, err)
assert.Equal(t, testIntegrationResponse, *ret)
})
}
func TestEventStreaming_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Create(context.Background(), api.CreateIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestEventStreaming_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, false, req.Enabled)
retBytes, _ := json.Marshal(testIntegrationResponse)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Update(context.Background(), 1, api.CreateIntegrationRequest{
Platform: api.CreateIntegrationRequestPlatformDatadog,
Enabled: false,
Config: map[string]string{"api_key": "updated-key"},
})
require.NoError(t, err)
assert.Equal(t, testIntegrationResponse, *ret)
})
}
func TestEventStreaming_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Update(context.Background(), 1, api.CreateIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestEventStreaming_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EventStreaming.Delete(context.Background(), 1)
require.NoError(t, err)
})
}
func TestEventStreaming_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.EventStreaming.Delete(context.Background(), 1)
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}

View File

@@ -2,6 +2,8 @@ package rest
import (
"context"
"fmt"
"time"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -11,10 +13,79 @@ type EventsAPI struct {
c *Client
}
// List list all events
// See more: https://docs.netbird.io/api/resources/events#list-all-events
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil, nil)
// NetworkTrafficOption options for ListNetworkTrafficEvents API
type NetworkTrafficOption func(query map[string]string)
func NetworkTrafficPage(page int) NetworkTrafficOption {
return func(query map[string]string) {
query["page"] = fmt.Sprintf("%d", page)
}
}
func NetworkTrafficPageSize(pageSize int) NetworkTrafficOption {
return func(query map[string]string) {
query["page_size"] = fmt.Sprintf("%d", pageSize)
}
}
func NetworkTrafficUserID(userID string) NetworkTrafficOption {
return func(query map[string]string) {
query["user_id"] = userID
}
}
func NetworkTrafficReporterID(reporterID string) NetworkTrafficOption {
return func(query map[string]string) {
query["reporter_id"] = reporterID
}
}
func NetworkTrafficProtocol(protocol int) NetworkTrafficOption {
return func(query map[string]string) {
query["protocol"] = fmt.Sprintf("%d", protocol)
}
}
func NetworkTrafficType(t api.GetApiEventsNetworkTrafficParamsType) NetworkTrafficOption {
return func(query map[string]string) {
query["type"] = string(t)
}
}
func NetworkTrafficConnectionType(ct api.GetApiEventsNetworkTrafficParamsConnectionType) NetworkTrafficOption {
return func(query map[string]string) {
query["connection_type"] = string(ct)
}
}
func NetworkTrafficDirection(d api.GetApiEventsNetworkTrafficParamsDirection) NetworkTrafficOption {
return func(query map[string]string) {
query["direction"] = string(d)
}
}
func NetworkTrafficSearch(search string) NetworkTrafficOption {
return func(query map[string]string) {
query["search"] = search
}
}
func NetworkTrafficStartDate(t time.Time) NetworkTrafficOption {
return func(query map[string]string) {
query["start_date"] = t.Format(time.RFC3339)
}
}
func NetworkTrafficEndDate(t time.Time) NetworkTrafficOption {
return func(query map[string]string) {
query["end_date"] = t.Format(time.RFC3339)
}
}
// ListAuditEvents list all audit events
// See more: https://docs.netbird.io/api/resources/events#list-all-audit-events
func (a *EventsAPI) ListAuditEvents(ctx context.Context) ([]api.Event, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/events/audit", nil, nil)
if err != nil {
return nil, err
}
@@ -24,3 +95,21 @@ func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
ret, err := parseResponse[[]api.Event](resp)
return ret, err
}
// ListNetworkTrafficEvents list network traffic events
// See more: https://docs.netbird.io/api/resources/events#list-network-traffic-events
func (a *EventsAPI) ListNetworkTrafficEvents(ctx context.Context, opts ...NetworkTrafficOption) (*api.NetworkTrafficEventsResponse, error) {
query := make(map[string]string)
for _, o := range opts {
o(query)
}
resp, err := a.c.NewRequest(ctx, "GET", "/api/events/network-traffic", nil, query)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkTrafficEventsResponse](resp)
return &ret, err
}

View File

@@ -21,37 +21,76 @@ var (
Activity: "AccountCreate",
ActivityCode: api.EventActivityCodeAccountCreate,
}
testNetworkTrafficResponse = api.NetworkTrafficEventsResponse{
Data: []api.NetworkTrafficEvent{},
Page: 1,
PageSize: 50,
}
)
func TestEvents_List_200(t *testing.T) {
func TestEvents_ListAuditEvents_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/api/events/audit", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Event{testEvent})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.List(context.Background())
ret, err := c.Events.ListAuditEvents(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testEvent, ret[0])
})
}
func TestEvents_List_Err(t *testing.T) {
func TestEvents_ListAuditEvents_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/api/events/audit", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.List(context.Background())
ret, err := c.Events.ListAuditEvents(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestEvents_ListNetworkTrafficEvents_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events/network-traffic", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "1", r.URL.Query().Get("page"))
assert.Equal(t, "50", r.URL.Query().Get("page_size"))
retBytes, _ := json.Marshal(testNetworkTrafficResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.ListNetworkTrafficEvents(context.Background(),
rest.NetworkTrafficPage(1),
rest.NetworkTrafficPageSize(50),
)
require.NoError(t, err)
assert.Equal(t, testNetworkTrafficResponse, *ret)
})
}
func TestEvents_ListNetworkTrafficEvents_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events/network-traffic", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.ListNetworkTrafficEvents(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestEvents_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
// Do something that would trigger any event
@@ -62,7 +101,7 @@ func TestEvents_Integration(t *testing.T) {
})
require.NoError(t, err)
events, err := c.Events.List(context.Background())
events, err := c.Events.ListAuditEvents(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, events)
})

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// IdentityProvidersAPI APIs for Identity Providers, do not use directly
type IdentityProvidersAPI struct {
c *Client
}
// List all identity providers
// See more: https://docs.netbird.io/api/resources/identity-providers#list-all-identity-providers
func (a *IdentityProvidersAPI) List(ctx context.Context) ([]api.IdentityProvider, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/identity-providers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IdentityProvider](resp)
return ret, err
}
// Get identity provider info
// See more: https://docs.netbird.io/api/resources/identity-providers#retrieve-an-identity-provider
func (a *IdentityProvidersAPI) Get(ctx context.Context, idpID string) (*api.IdentityProvider, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/identity-providers/"+idpID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IdentityProvider](resp)
return &ret, err
}
// Create new identity provider
// See more: https://docs.netbird.io/api/resources/identity-providers#create-an-identity-provider
func (a *IdentityProvidersAPI) Create(ctx context.Context, request api.PostApiIdentityProvidersJSONRequestBody) (*api.IdentityProvider, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/identity-providers", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IdentityProvider](resp)
return &ret, err
}
// Update update identity provider
// See more: https://docs.netbird.io/api/resources/identity-providers#update-an-identity-provider
func (a *IdentityProvidersAPI) Update(ctx context.Context, idpID string, request api.PutApiIdentityProvidersIdpIdJSONRequestBody) (*api.IdentityProvider, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/identity-providers/"+idpID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IdentityProvider](resp)
return &ret, err
}
// Delete delete identity provider
// See more: https://docs.netbird.io/api/resources/identity-providers#delete-an-identity-provider
func (a *IdentityProvidersAPI) Delete(ctx context.Context, idpID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/identity-providers/"+idpID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,183 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var testIdentityProvider = api.IdentityProvider{
ClientId: "test-client-id",
Id: ptr("Test"),
}
func TestIdentityProviders_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.IdentityProvider{testIdentityProvider})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testIdentityProvider, ret[0])
})
}
func TestIdentityProviders_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestIdentityProviders_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testIdentityProvider)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testIdentityProvider, *ret)
})
}
func TestIdentityProviders_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestIdentityProviders_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiIdentityProvidersJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "new-client-id", req.ClientId)
retBytes, _ := json.Marshal(testIdentityProvider)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Create(context.Background(), api.PostApiIdentityProvidersJSONRequestBody{
ClientId: "new-client-id",
})
require.NoError(t, err)
assert.Equal(t, testIdentityProvider, *ret)
})
}
func TestIdentityProviders_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Create(context.Background(), api.PostApiIdentityProvidersJSONRequestBody{
ClientId: "new-client-id",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestIdentityProviders_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiIdentityProvidersIdpIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "updated-client-id", req.ClientId)
retBytes, _ := json.Marshal(testIdentityProvider)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Update(context.Background(), "Test", api.PutApiIdentityProvidersIdpIdJSONRequestBody{
ClientId: "updated-client-id",
})
require.NoError(t, err)
assert.Equal(t, testIdentityProvider, *ret)
})
}
func TestIdentityProviders_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Update(context.Background(), "Test", api.PutApiIdentityProvidersIdpIdJSONRequestBody{
ClientId: "updated-client-id",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestIdentityProviders_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.IdentityProviders.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestIdentityProviders_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.IdentityProviders.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// IngressAPI APIs for Ingress Peers, do not use directly
type IngressAPI struct {
c *Client
}
// List all ingress peers
// See more: https://docs.netbird.io/api/resources/ingress#list-all-ingress-peers
func (a *IngressAPI) List(ctx context.Context) ([]api.IngressPeer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/ingress/peers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IngressPeer](resp)
return ret, err
}
// Get ingress peer info
// See more: https://docs.netbird.io/api/resources/ingress#retrieve-an-ingress-peer
func (a *IngressAPI) Get(ctx context.Context, ingressPeerID string) (*api.IngressPeer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/ingress/peers/"+ingressPeerID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPeer](resp)
return &ret, err
}
// Create new ingress peer
// See more: https://docs.netbird.io/api/resources/ingress#create-an-ingress-peer
func (a *IngressAPI) Create(ctx context.Context, request api.PostApiIngressPeersJSONRequestBody) (*api.IngressPeer, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/ingress/peers", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPeer](resp)
return &ret, err
}
// Update update ingress peer
// See more: https://docs.netbird.io/api/resources/ingress#update-an-ingress-peer
func (a *IngressAPI) Update(ctx context.Context, ingressPeerID string, request api.PutApiIngressPeersIngressPeerIdJSONRequestBody) (*api.IngressPeer, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/ingress/peers/"+ingressPeerID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPeer](resp)
return &ret, err
}
// Delete delete ingress peer
// See more: https://docs.netbird.io/api/resources/ingress#delete-an-ingress-peer
func (a *IngressAPI) Delete(ctx context.Context, ingressPeerID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/ingress/peers/"+ingressPeerID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,184 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var testIngressPeer = api.IngressPeer{
Connected: true,
Enabled: true,
Id: "Test",
}
func TestIngress_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.IngressPeer{testIngressPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testIngressPeer, ret[0])
})
}
func TestIngress_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestIngress_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testIngressPeer)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testIngressPeer, *ret)
})
}
func TestIngress_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestIngress_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiIngressPeersJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "peer-id", req.PeerId)
retBytes, _ := json.Marshal(testIngressPeer)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Create(context.Background(), api.PostApiIngressPeersJSONRequestBody{
PeerId: "peer-id",
})
require.NoError(t, err)
assert.Equal(t, testIngressPeer, *ret)
})
}
func TestIngress_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Create(context.Background(), api.PostApiIngressPeersJSONRequestBody{
PeerId: "peer-id",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestIngress_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiIngressPeersIngressPeerIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, req.Enabled)
retBytes, _ := json.Marshal(testIngressPeer)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Update(context.Background(), "Test", api.PutApiIngressPeersIngressPeerIdJSONRequestBody{
Enabled: true,
})
require.NoError(t, err)
assert.Equal(t, testIngressPeer, *ret)
})
}
func TestIngress_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Update(context.Background(), "Test", api.PutApiIngressPeersIngressPeerIdJSONRequestBody{
Enabled: true,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestIngress_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Ingress.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestIngress_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Ingress.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}

View File

@@ -0,0 +1,46 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// InstanceAPI APIs for Instance status and version, do not use directly
type InstanceAPI struct {
c *Client
}
// GetStatus get instance status
// See more: https://docs.netbird.io/api/resources/instance#get-instance-status
func (a *InstanceAPI) GetStatus(ctx context.Context) (*api.InstanceStatus, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/instance", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.InstanceStatus](resp)
return &ret, err
}
// Setup perform initial instance setup
// See more: https://docs.netbird.io/api/resources/instance#setup-instance
func (a *InstanceAPI) Setup(ctx context.Context, request api.PostApiSetupJSONRequestBody) (*api.SetupResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/setup", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SetupResponse](resp)
return &ret, err
}

View File

@@ -0,0 +1,96 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testInstanceStatus = api.InstanceStatus{
SetupRequired: true,
}
testSetupResponse = api.SetupResponse{
Email: "admin@example.com",
UserId: "user-123",
}
)
func TestInstance_GetStatus_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/instance", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testInstanceStatus)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Instance.GetStatus(context.Background())
require.NoError(t, err)
assert.Equal(t, testInstanceStatus, *ret)
})
}
func TestInstance_GetStatus_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/instance", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Instance.GetStatus(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestInstance_Setup_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiSetupJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "admin@example.com", req.Email)
retBytes, _ := json.Marshal(testSetupResponse)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Instance.Setup(context.Background(), api.PostApiSetupJSONRequestBody{
Email: "admin@example.com",
})
require.NoError(t, err)
assert.Equal(t, testSetupResponse, *ret)
})
}
func TestInstance_Setup_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Instance.Setup(context.Background(), api.PostApiSetupJSONRequestBody{
Email: "admin@example.com",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}

View File

@@ -0,0 +1,122 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// MSPAPI APIs for MSP tenant management
type MSPAPI struct {
c *Client
}
// ListTenants retrieves all MSP tenants
// See more: https://docs.netbird.io/api/resources/msp#list-all-tenants
func (a *MSPAPI) ListTenants(ctx context.Context) (*api.GetTenantsResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/msp/tenants", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.GetTenantsResponse](resp)
return &ret, err
}
// CreateTenant creates a new MSP tenant
// See more: https://docs.netbird.io/api/resources/msp#create-a-tenant
func (a *MSPAPI) CreateTenant(ctx context.Context, request api.CreateTenantRequest) (*api.TenantResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.TenantResponse](resp)
return &ret, err
}
// UpdateTenant updates an existing MSP tenant
// See more: https://docs.netbird.io/api/resources/msp#update-a-tenant
func (a *MSPAPI) UpdateTenant(ctx context.Context, tenantID string, request api.UpdateTenantRequest) (*api.TenantResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/msp/tenants/"+tenantID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.TenantResponse](resp)
return &ret, err
}
// DeleteTenant deletes an MSP tenant
// See more: https://docs.netbird.io/api/resources/msp#delete-a-tenant
func (a *MSPAPI) DeleteTenant(ctx context.Context, tenantID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/msp/tenants/"+tenantID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// UnlinkTenant unlinks a tenant from the MSP account
// See more: https://docs.netbird.io/api/resources/msp#unlink-a-tenant
func (a *MSPAPI) UnlinkTenant(ctx context.Context, tenantID, owner string) error {
params := map[string]string{"owner": owner}
requestBytes, err := json.Marshal(params)
if err != nil {
return err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/unlink", bytes.NewReader(requestBytes), nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// VerifyTenantDNS verifies a tenant domain DNS challenge
// See more: https://docs.netbird.io/api/resources/msp#verify-tenant-dns
func (a *MSPAPI) VerifyTenantDNS(ctx context.Context, tenantID string) error {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/dns", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// InviteTenant invites an existing account as a tenant to the MSP account
// See more: https://docs.netbird.io/api/resources/msp#invite-a-tenant
func (a *MSPAPI) InviteTenant(ctx context.Context, tenantID string) (*api.TenantResponse, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/invite", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.TenantResponse](resp)
return &ret, err
}

View File

@@ -0,0 +1,251 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testTenant = api.TenantResponse{
Id: "tenant-1",
Name: "Test Tenant",
Domain: "test.example.com",
DnsChallenge: "challenge-123",
Status: "active",
Groups: []api.TenantGroupResponse{},
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
)
func TestMSP_ListTenants_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.TenantResponse{testTenant})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.ListTenants(context.Background())
require.NoError(t, err)
assert.Len(t, *ret, 1)
assert.Equal(t, testTenant, (*ret)[0])
})
}
func TestMSP_ListTenants_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.ListTenants(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestMSP_CreateTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateTenantRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "Test Tenant", req.Name)
assert.Equal(t, "test.example.com", req.Domain)
retBytes, _ := json.Marshal(testTenant)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.CreateTenant(context.Background(), api.CreateTenantRequest{
Name: "Test Tenant",
Domain: "test.example.com",
Groups: []api.TenantGroupResponse{},
})
require.NoError(t, err)
assert.Equal(t, testTenant, *ret)
})
}
func TestMSP_CreateTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.CreateTenant(context.Background(), api.CreateTenantRequest{
Name: "Test Tenant",
Domain: "test.example.com",
Groups: []api.TenantGroupResponse{},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestMSP_UpdateTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.UpdateTenantRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "Updated Tenant", req.Name)
retBytes, _ := json.Marshal(testTenant)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.UpdateTenant(context.Background(), "tenant-1", api.UpdateTenantRequest{
Name: "Updated Tenant",
Groups: []api.TenantGroupResponse{},
})
require.NoError(t, err)
assert.Equal(t, testTenant, *ret)
})
}
func TestMSP_UpdateTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.UpdateTenant(context.Background(), "tenant-1", api.UpdateTenantRequest{
Name: "Updated Tenant",
Groups: []api.TenantGroupResponse{},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestMSP_DeleteTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.MSP.DeleteTenant(context.Background(), "tenant-1")
require.NoError(t, err)
})
}
func TestMSP_DeleteTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.MSP.DeleteTenant(context.Background(), "tenant-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestMSP_UnlinkTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/unlink", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
w.WriteHeader(200)
})
err := c.MSP.UnlinkTenant(context.Background(), "tenant-1", "owner-1")
require.NoError(t, err)
})
}
func TestMSP_UnlinkTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/unlink", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.MSP.UnlinkTenant(context.Background(), "tenant-1", "owner-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestMSP_VerifyTenantDNS_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/dns", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
w.WriteHeader(200)
})
err := c.MSP.VerifyTenantDNS(context.Background(), "tenant-1")
require.NoError(t, err)
})
}
func TestMSP_VerifyTenantDNS_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/dns", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Failed", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.MSP.VerifyTenantDNS(context.Background(), "tenant-1")
assert.Error(t, err)
assert.Equal(t, "Failed", err.Error())
})
}
func TestMSP_InviteTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/invite", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testTenant)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.InviteTenant(context.Background(), "tenant-1")
require.NoError(t, err)
assert.Equal(t, testTenant, *ret)
})
}
func TestMSP_InviteTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/invite", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.InviteTenant(context.Background(), "tenant-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}

View File

@@ -91,6 +91,20 @@ func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
return nil
}
// ListAllRouters list all routers across all networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-routers
func (a *NetworksAPI) ListAllRouters(ctx context.Context) ([]api.NetworkRouter, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/routers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.NetworkRouter](resp)
return ret, err
}
// NetworkResourcesAPI APIs for Network Resources, do not use directly
type NetworkResourcesAPI struct {
c *Client

View File

@@ -219,6 +219,35 @@ func TestNetworks_Integration(t *testing.T) {
})
}
func TestNetworks_ListAllRouters_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.ListAllRouters(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNetworkRouter, ret[0])
})
}
func TestNetworks_ListAllRouters_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.ListAllRouters(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkResources_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {

View File

@@ -106,3 +106,173 @@ func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]ap
ret, err := parseResponse[[]api.Peer](resp)
return ret, err
}
// CreateTemporaryAccess create temporary access for a peer
// See more: https://docs.netbird.io/api/resources/peers#create-temporary-access
func (a *PeersAPI) CreateTemporaryAccess(ctx context.Context, peerID string, request api.PostApiPeersPeerIdTemporaryAccessJSONRequestBody) (*api.PeerTemporaryAccessResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+peerID+"/temporary-access", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PeerTemporaryAccessResponse](resp)
return &ret, err
}
// PeerIngressPortsAPI APIs for Peer Ingress Ports, do not use directly
type PeerIngressPortsAPI struct {
c *Client
peerID string
}
// IngressPorts APIs for peer ingress ports
func (a *PeersAPI) IngressPorts(peerID string) *PeerIngressPortsAPI {
return &PeerIngressPortsAPI{
c: a.c,
peerID: peerID,
}
}
// List list all ingress port allocations for a peer
// See more: https://docs.netbird.io/api/resources/peers#list-all-ingress-port-allocations
func (a *PeerIngressPortsAPI) List(ctx context.Context) ([]api.IngressPortAllocation, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+a.peerID+"/ingress/ports", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IngressPortAllocation](resp)
return ret, err
}
// Get get ingress port allocation info
// See more: https://docs.netbird.io/api/resources/peers#retrieve-an-ingress-port-allocation
func (a *PeerIngressPortsAPI) Get(ctx context.Context, allocationID string) (*api.IngressPortAllocation, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+a.peerID+"/ingress/ports/"+allocationID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPortAllocation](resp)
return &ret, err
}
// Create create new ingress port allocation
// See more: https://docs.netbird.io/api/resources/peers#create-an-ingress-port-allocation
func (a *PeerIngressPortsAPI) Create(ctx context.Context, request api.PostApiPeersPeerIdIngressPortsJSONRequestBody) (*api.IngressPortAllocation, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+a.peerID+"/ingress/ports", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPortAllocation](resp)
return &ret, err
}
// Update update ingress port allocation
// See more: https://docs.netbird.io/api/resources/peers#update-an-ingress-port-allocation
func (a *PeerIngressPortsAPI) Update(ctx context.Context, allocationID string, request api.PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody) (*api.IngressPortAllocation, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+a.peerID+"/ingress/ports/"+allocationID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPortAllocation](resp)
return &ret, err
}
// Delete delete ingress port allocation
// See more: https://docs.netbird.io/api/resources/peers#delete-an-ingress-port-allocation
func (a *PeerIngressPortsAPI) Delete(ctx context.Context, allocationID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+a.peerID+"/ingress/ports/"+allocationID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// PeerJobsAPI APIs for Peer Jobs, do not use directly
type PeerJobsAPI struct {
c *Client
peerID string
}
// Jobs APIs for peer jobs
func (a *PeersAPI) Jobs(peerID string) *PeerJobsAPI {
return &PeerJobsAPI{
c: a.c,
peerID: peerID,
}
}
// List list all jobs for a peer
// See more: https://docs.netbird.io/api/resources/peers#list-all-peer-jobs
func (a *PeerJobsAPI) List(ctx context.Context) ([]api.JobResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+a.peerID+"/jobs", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.JobResponse](resp)
return ret, err
}
// Get get job info
// See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer-job
func (a *PeerJobsAPI) Get(ctx context.Context, jobID string) (*api.JobResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+a.peerID+"/jobs/"+jobID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.JobResponse](resp)
return &ret, err
}
// Create create new job for a peer
// See more: https://docs.netbird.io/api/resources/peers#create-a-peer-job
func (a *PeerJobsAPI) Create(ctx context.Context, request api.PostApiPeersPeerIdJobsJSONRequestBody) (*api.JobResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+a.peerID+"/jobs", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.JobResponse](resp)
return &ret, err
}

View File

@@ -25,6 +25,21 @@ var (
DnsLabel: "test",
Id: "Test",
}
testPeerTemporaryAccess = api.PeerTemporaryAccessResponse{
Id: "Test",
Name: "test-peer",
}
testIngressPortAllocation = api.IngressPortAllocation{
Enabled: true,
Id: "alloc-1",
}
testJobResponse = api.JobResponse{
Id: "job-1",
Status: "pending",
}
)
func TestPeers_List_200(t *testing.T) {
@@ -177,6 +192,264 @@ func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
})
}
func TestPeers_CreateTemporaryAccess_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/temporary-access", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testPeerTemporaryAccess)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.CreateTemporaryAccess(context.Background(), "Test", api.PostApiPeersPeerIdTemporaryAccessJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testPeerTemporaryAccess, *ret)
})
}
func TestPeers_CreateTemporaryAccess_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/temporary-access", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.CreateTemporaryAccess(context.Background(), "Test", api.PostApiPeersPeerIdTemporaryAccessJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeerIngressPorts_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.IngressPortAllocation{testIngressPortAllocation})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testIngressPortAllocation, ret[0])
})
}
func TestPeerIngressPorts_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeerIngressPorts_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testIngressPortAllocation)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Get(context.Background(), "alloc-1")
require.NoError(t, err)
assert.Equal(t, testIngressPortAllocation, *ret)
})
}
func TestPeerIngressPorts_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Get(context.Background(), "alloc-1")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeerIngressPorts_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testIngressPortAllocation)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Create(context.Background(), api.PostApiPeersPeerIdIngressPortsJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testIngressPortAllocation, *ret)
})
}
func TestPeerIngressPorts_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Create(context.Background(), api.PostApiPeersPeerIdIngressPortsJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeerIngressPorts_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
retBytes, _ := json.Marshal(testIngressPortAllocation)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Update(context.Background(), "alloc-1", api.PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testIngressPortAllocation, *ret)
})
}
func TestPeerIngressPorts_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Update(context.Background(), "alloc-1", api.PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeerIngressPorts_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Peers.IngressPorts("Test").Delete(context.Background(), "alloc-1")
require.NoError(t, err)
})
}
func TestPeerIngressPorts_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Peers.IngressPorts("Test").Delete(context.Background(), "alloc-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestPeerJobs_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.JobResponse{testJobResponse})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testJobResponse.Id, ret[0].Id)
assert.Equal(t, testJobResponse.Status, ret[0].Status)
})
}
func TestPeerJobs_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeerJobs_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs/job-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testJobResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").Get(context.Background(), "job-1")
require.NoError(t, err)
assert.Equal(t, testJobResponse.Id, ret.Id)
assert.Equal(t, testJobResponse.Status, ret.Status)
})
}
func TestPeerJobs_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs/job-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").Get(context.Background(), "job-1")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeerJobs_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testJobResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").Create(context.Background(), api.PostApiPeersPeerIdJobsJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testJobResponse.Id, ret.Id)
assert.Equal(t, testJobResponse.Status, ret.Status)
})
}
func TestPeerJobs_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").Create(context.Background(), api.PostApiPeersPeerIdJobsJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
peers, err := c.Peers.List(context.Background())

View File

@@ -0,0 +1,119 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// SCIMAPI APIs for SCIM IDP integrations
type SCIMAPI struct {
c *Client
}
// List retrieves all SCIM IDP integrations
// See more: https://docs.netbird.io/api/resources/scim#list-all-scim-integrations
func (a *SCIMAPI) List(ctx context.Context) ([]api.ScimIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/scim-idp", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.ScimIntegration](resp)
return ret, err
}
// Get retrieves a specific SCIM IDP integration by ID
// See more: https://docs.netbird.io/api/resources/scim#retrieve-a-scim-integration
func (a *SCIMAPI) Get(ctx context.Context, integrationID string) (*api.ScimIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/scim-idp/"+integrationID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimIntegration](resp)
return &ret, err
}
// Create creates a new SCIM IDP integration
// See more: https://docs.netbird.io/api/resources/scim#create-a-scim-integration
func (a *SCIMAPI) Create(ctx context.Context, request api.CreateScimIntegrationRequest) (*api.ScimIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/scim-idp", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimIntegration](resp)
return &ret, err
}
// Update updates an existing SCIM IDP integration
// See more: https://docs.netbird.io/api/resources/scim#update-a-scim-integration
func (a *SCIMAPI) Update(ctx context.Context, integrationID string, request api.UpdateScimIntegrationRequest) (*api.ScimIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/scim-idp/"+integrationID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimIntegration](resp)
return &ret, err
}
// Delete deletes a SCIM IDP integration
// See more: https://docs.netbird.io/api/resources/scim#delete-a-scim-integration
func (a *SCIMAPI) Delete(ctx context.Context, integrationID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/scim-idp/"+integrationID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// RegenerateToken regenerates the SCIM API token for an integration
// See more: https://docs.netbird.io/api/resources/scim#regenerate-scim-token
func (a *SCIMAPI) RegenerateToken(ctx context.Context, integrationID string) (*api.ScimTokenResponse, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/scim-idp/"+integrationID+"/token", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimTokenResponse](resp)
return &ret, err
}
// GetLogs retrieves synchronization logs for an SCIM IDP integration
// See more: https://docs.netbird.io/api/resources/scim#get-scim-sync-logs
func (a *SCIMAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/scim-idp/"+integrationID+"/logs", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp)
return ret, err
}

View File

@@ -0,0 +1,262 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testScimIntegration = api.ScimIntegration{
Id: 1,
AuthToken: "****",
Enabled: true,
GroupPrefixes: []string{"eng-"},
UserGroupPrefixes: []string{"dev-"},
Provider: "okta",
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
testScimToken = api.ScimTokenResponse{
AuthToken: "new-token-123",
}
testSyncLog = api.IdpIntegrationSyncLog{
Id: 1,
Level: "info",
Message: "Sync completed",
Timestamp: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
)
func TestSCIM_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.ScimIntegration{testScimIntegration})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testScimIntegration, ret[0])
})
}
func TestSCIM_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestSCIM_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testScimIntegration)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Get(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, testScimIntegration, *ret)
})
}
func TestSCIM_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Get(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestSCIM_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateScimIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "okta", req.Provider)
assert.Equal(t, "scim-", req.Prefix)
retBytes, _ := json.Marshal(testScimIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Create(context.Background(), api.CreateScimIntegrationRequest{
Provider: "okta",
Prefix: "scim-",
GroupPrefixes: &[]string{"eng-"},
})
require.NoError(t, err)
assert.Equal(t, testScimIntegration, *ret)
})
}
func TestSCIM_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Create(context.Background(), api.CreateScimIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestSCIM_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.UpdateScimIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, *req.Enabled)
retBytes, _ := json.Marshal(testScimIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Update(context.Background(), "int-1", api.UpdateScimIntegrationRequest{
Enabled: ptr(true),
})
require.NoError(t, err)
assert.Equal(t, testScimIntegration, *ret)
})
}
func TestSCIM_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Update(context.Background(), "int-1", api.UpdateScimIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestSCIM_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.SCIM.Delete(context.Background(), "int-1")
require.NoError(t, err)
})
}
func TestSCIM_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.SCIM.Delete(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestSCIM_RegenerateToken_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testScimToken)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.RegenerateToken(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, testScimToken, *ret)
})
}
func TestSCIM_RegenerateToken_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.RegenerateToken(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestSCIM_GetLogs_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.GetLogs(context.Background(), "int-1")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testSyncLog, ret[0])
})
}
func TestSCIM_GetLogs_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.GetLogs(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}

View File

@@ -105,3 +105,145 @@ func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
ret, err := parseResponse[api.User](resp)
return &ret, err
}
// ListInvites list all user invites
// See more: https://docs.netbird.io/api/resources/users#list-all-user-invites
func (a *UsersAPI) ListInvites(ctx context.Context) ([]api.UserInvite, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/invites", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.UserInvite](resp)
return ret, err
}
// CreateInvite create a user invite
// See more: https://docs.netbird.io/api/resources/users#create-a-user-invite
func (a *UsersAPI) CreateInvite(ctx context.Context, request api.PostApiUsersInvitesJSONRequestBody) (*api.UserInvite, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/invites", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UserInvite](resp)
return &ret, err
}
// DeleteInvite delete a user invite
// See more: https://docs.netbird.io/api/resources/users#delete-a-user-invite
func (a *UsersAPI) DeleteInvite(ctx context.Context, inviteID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/invites/"+inviteID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// RegenerateInvite regenerate a user invite token
// See more: https://docs.netbird.io/api/resources/users#regenerate-a-user-invite
func (a *UsersAPI) RegenerateInvite(ctx context.Context, inviteID string, request api.PostApiUsersInvitesInviteIdRegenerateJSONRequestBody) (*api.UserInviteRegenerateResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/invites/"+inviteID+"/regenerate", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UserInviteRegenerateResponse](resp)
return &ret, err
}
// GetInviteByToken get a user invite by token
// See more: https://docs.netbird.io/api/resources/users#get-a-user-invite-by-token
func (a *UsersAPI) GetInviteByToken(ctx context.Context, token string) (*api.UserInviteInfo, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/invites/"+token, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UserInviteInfo](resp)
return &ret, err
}
// AcceptInvite accept a user invite
// See more: https://docs.netbird.io/api/resources/users#accept-a-user-invite
func (a *UsersAPI) AcceptInvite(ctx context.Context, token string, request api.PostApiUsersInvitesTokenAcceptJSONRequestBody) (*api.UserInviteAcceptResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/invites/"+token+"/accept", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UserInviteAcceptResponse](resp)
return &ret, err
}
// Approve approve a pending user
// See more: https://docs.netbird.io/api/resources/users#approve-a-user
func (a *UsersAPI) Approve(ctx context.Context, userID string) (*api.User, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/approve", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.User](resp)
return &ret, err
}
// ChangePassword change a user's password
// See more: https://docs.netbird.io/api/resources/users#change-user-password
func (a *UsersAPI) ChangePassword(ctx context.Context, userID string, request api.PutApiUsersUserIdPasswordJSONRequestBody) error {
requestBytes, err := json.Marshal(request)
if err != nil {
return err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID+"/password", bytes.NewReader(requestBytes), nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// Reject reject a pending user
// See more: https://docs.netbird.io/api/resources/users#reject-a-user
func (a *UsersAPI) Reject(ctx context.Context, userID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/reject", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -32,6 +32,23 @@ var (
Role: "user",
Status: api.UserStatusActive,
}
testUserInvite = api.UserInvite{
AutoGroups: []string{"group1"},
Id: "invite-1",
}
testUserInviteInfo = api.UserInviteInfo{
Email: "invite@test.com",
}
testUserInviteAcceptResponse = api.UserInviteAcceptResponse{
Success: true,
}
testUserInviteRegenerateResponse = api.UserInviteRegenerateResponse{
InviteToken: "new-token",
}
)
func TestUsers_List_200(t *testing.T) {
@@ -220,6 +237,269 @@ func TestUsers_Current_Err(t *testing.T) {
})
}
func TestUsers_ListInvites_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.UserInvite{testUserInvite})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.ListInvites(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testUserInvite, ret[0])
})
}
func TestUsers_ListInvites_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.ListInvites(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestUsers_CreateInvite_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiUsersInvitesJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "invite@test.com", req.Email)
retBytes, _ := json.Marshal(testUserInvite)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.CreateInvite(context.Background(), api.PostApiUsersInvitesJSONRequestBody{
Email: "invite@test.com",
})
require.NoError(t, err)
assert.Equal(t, testUserInvite, *ret)
})
}
func TestUsers_CreateInvite_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.CreateInvite(context.Background(), api.PostApiUsersInvitesJSONRequestBody{
Email: "invite@test.com",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_DeleteInvite_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/invite-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Users.DeleteInvite(context.Background(), "invite-1")
require.NoError(t, err)
})
}
func TestUsers_DeleteInvite_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/invite-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.DeleteInvite(context.Background(), "invite-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestUsers_RegenerateInvite_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/invite-1/regenerate", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testUserInviteRegenerateResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.RegenerateInvite(context.Background(), "invite-1", api.PostApiUsersInvitesInviteIdRegenerateJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testUserInviteRegenerateResponse, *ret)
})
}
func TestUsers_RegenerateInvite_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/invite-1/regenerate", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.RegenerateInvite(context.Background(), "invite-1", api.PostApiUsersInvitesInviteIdRegenerateJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_GetInviteByToken_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/some-token", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testUserInviteInfo)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.GetInviteByToken(context.Background(), "some-token")
require.NoError(t, err)
assert.Equal(t, testUserInviteInfo, *ret)
})
}
func TestUsers_GetInviteByToken_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/some-token", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.GetInviteByToken(context.Background(), "some-token")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestUsers_AcceptInvite_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/some-token/accept", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testUserInviteAcceptResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.AcceptInvite(context.Background(), "some-token", api.PostApiUsersInvitesTokenAcceptJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testUserInviteAcceptResponse, *ret)
})
}
func TestUsers_AcceptInvite_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/some-token/accept", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.AcceptInvite(context.Background(), "some-token", api.PostApiUsersInvitesTokenAcceptJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_Approve_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/approve", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testUser)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Approve(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testUser, *ret)
})
}
func TestUsers_Approve_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/approve", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Approve(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_ChangePassword_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/password", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiUsersUserIdPasswordJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
w.WriteHeader(200)
})
err := c.Users.ChangePassword(context.Background(), "Test", api.PutApiUsersUserIdPasswordJSONRequestBody{})
require.NoError(t, err)
})
}
func TestUsers_ChangePassword_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/password", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.ChangePassword(context.Background(), "Test", api.PutApiUsersUserIdPasswordJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
})
}
func TestUsers_Reject_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/reject", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Users.Reject(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestUsers_Reject_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/reject", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.Reject(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
})
}
func TestUsers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
// rest client PAT is owner's

File diff suppressed because it is too large Load Diff

View File

@@ -16,6 +16,14 @@ const (
TokenAuthScopes = "TokenAuth.Scopes"
)
// Defines values for CreateIntegrationRequestPlatform.
const (
CreateIntegrationRequestPlatformDatadog CreateIntegrationRequestPlatform = "datadog"
CreateIntegrationRequestPlatformFirehose CreateIntegrationRequestPlatform = "firehose"
CreateIntegrationRequestPlatformGenericHttp CreateIntegrationRequestPlatform = "generic_http"
CreateIntegrationRequestPlatformS3 CreateIntegrationRequestPlatform = "s3"
)
// Defines values for DNSRecordType.
const (
DNSRecordTypeA DNSRecordType = "A"
@@ -188,6 +196,20 @@ const (
IngressPortAllocationRequestPortRangeProtocolUdp IngressPortAllocationRequestPortRangeProtocol = "udp"
)
// Defines values for IntegrationResponsePlatform.
const (
IntegrationResponsePlatformDatadog IntegrationResponsePlatform = "datadog"
IntegrationResponsePlatformFirehose IntegrationResponsePlatform = "firehose"
IntegrationResponsePlatformGenericHttp IntegrationResponsePlatform = "generic_http"
IntegrationResponsePlatformS3 IntegrationResponsePlatform = "s3"
)
// Defines values for InvoiceResponseType.
const (
InvoiceResponseTypeAccount InvoiceResponseType = "account"
InvoiceResponseTypeTenants InvoiceResponseType = "tenants"
)
// Defines values for JobResponseStatus.
const (
JobResponseStatusFailed JobResponseStatus = "failed"
@@ -266,6 +288,21 @@ const (
ResourceTypeSubnet ResourceType = "subnet"
)
// Defines values for SentinelOneMatchAttributesNetworkStatus.
const (
SentinelOneMatchAttributesNetworkStatusConnected SentinelOneMatchAttributesNetworkStatus = "connected"
SentinelOneMatchAttributesNetworkStatusDisconnected SentinelOneMatchAttributesNetworkStatus = "disconnected"
SentinelOneMatchAttributesNetworkStatusQuarantined SentinelOneMatchAttributesNetworkStatus = "quarantined"
)
// Defines values for TenantResponseStatus.
const (
TenantResponseStatusActive TenantResponseStatus = "active"
TenantResponseStatusExisting TenantResponseStatus = "existing"
TenantResponseStatusInvited TenantResponseStatus = "invited"
TenantResponseStatusPending TenantResponseStatus = "pending"
)
// Defines values for UserStatus.
const (
UserStatusActive UserStatus = "active"
@@ -299,6 +336,12 @@ const (
GetApiEventsNetworkTrafficParamsDirectionINGRESS GetApiEventsNetworkTrafficParamsDirection = "INGRESS"
)
// Defines values for PutApiIntegrationsMspTenantsIdInviteJSONBodyValue.
const (
PutApiIntegrationsMspTenantsIdInviteJSONBodyValueAccept PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "accept"
PutApiIntegrationsMspTenantsIdInviteJSONBodyValueDecline PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "decline"
)
// AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct {
// CityName Commonly used English name of the city
@@ -490,6 +533,21 @@ type BundleWorkloadResponse struct {
Type WorkloadType `json:"type"`
}
// BypassResponse Response for bypassed peer operations.
type BypassResponse struct {
// PeerId The ID of the bypassed peer.
PeerId string `json:"peer_id"`
}
// CheckoutResponse defines model for CheckoutResponse.
type CheckoutResponse struct {
// SessionId The unique identifier for the checkout session.
SessionId string `json:"session_id"`
// Url URL to redirect the user to the checkout session.
Url string `json:"url"`
}
// Checks List of objects that perform the actual checks
type Checks struct {
// GeoLocationCheck Posture check for geo location
@@ -532,6 +590,36 @@ type Country struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
type CountryCode = string
// CreateIntegrationRequest Request payload for creating a new event streaming integration. Also used as the structure for the PUT request body, but not all fields are applicable for updates (see PUT operation description).
type CreateIntegrationRequest struct {
// Config Platform-specific configuration as key-value pairs. For creation, all necessary credentials and settings must be provided. For updates, provide the fields to change or the entire new configuration.
Config map[string]string `json:"config"`
// Enabled Specifies whether the integration is enabled. During creation (POST), this value is sent by the client, but the provided backend manager function `CreateIntegration` does not appear to use it directly, so its effect on creation should be verified. During updates (PUT), this field is used to enable or disable the integration.
Enabled bool `json:"enabled"`
// Platform The event streaming platform to integrate with (e.g., "datadog", "s3", "firehose"). This field is used for creation. For updates (PUT), this field, if sent, is ignored by the backend.
Platform CreateIntegrationRequestPlatform `json:"platform"`
}
// CreateIntegrationRequestPlatform The event streaming platform to integrate with (e.g., "datadog", "s3", "firehose"). This field is used for creation. For updates (PUT), this field, if sent, is ignored by the backend.
type CreateIntegrationRequestPlatform string
// CreateScimIntegrationRequest Request payload for creating an SCIM IDP integration
type CreateScimIntegrationRequest struct {
// GroupPrefixes List of start_with string patterns for groups to sync
GroupPrefixes *[]string `json:"group_prefixes,omitempty"`
// Prefix The connection prefix used for the SCIM provider
Prefix string `json:"prefix"`
// Provider Name of the SCIM identity provider
Provider string `json:"provider"`
// UserGroupPrefixes List of start_with string patterns for groups which users to sync
UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"`
}
// CreateSetupKeyRequest defines model for CreateSetupKeyRequest.
type CreateSetupKeyRequest struct {
// AllowExtraDnsLabels Allow extra DNS labels to be added to the peer
@@ -556,6 +644,24 @@ type CreateSetupKeyRequest struct {
UsageLimit int `json:"usage_limit"`
}
// CreateTenantRequest defines model for CreateTenantRequest.
type CreateTenantRequest struct {
// Domain The name for the MSP tenant
Domain string `json:"domain"`
// Groups MSP users Groups that can access the Tenant and Roles to assume
Groups []TenantGroupResponse `json:"groups"`
// Name The name for the MSP tenant
Name string `json:"name"`
}
// DNSChallengeResponse defines model for DNSChallengeResponse.
type DNSChallengeResponse struct {
// DnsChallenge The DNS challenge to set in a TXT record
DnsChallenge string `json:"dns_challenge"`
}
// DNSRecord defines model for DNSRecord.
type DNSRecord struct {
// Content DNS record content (IP address for A/AAAA, domain for CNAME)
@@ -598,6 +704,234 @@ type DNSSettings struct {
DisabledManagementGroups []string `json:"disabled_management_groups"`
}
// EDRFalconRequest Request payload for creating or updating a EDR Falcon integration
type EDRFalconRequest struct {
// ClientId CrowdStrike API client ID
ClientId string `json:"client_id"`
// CloudId CrowdStrike cloud identifier (e.g., "us-1", "us-2", "eu-1")
CloudId string `json:"cloud_id"`
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// Groups The Groups this integration applies to
Groups []string `json:"groups"`
// Secret CrowdStrike API client secret
Secret string `json:"secret"`
// ZtaScoreThreshold The minimum Zero Trust Assessment score required for agent approval (0-100)
ZtaScoreThreshold int `json:"zta_score_threshold"`
}
// EDRFalconResponse Represents a Falcon EDR integration
type EDRFalconResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId string `json:"account_id"`
// CloudId CrowdStrike cloud identifier
CloudId string `json:"cloud_id"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt time.Time `json:"created_at"`
// CreatedBy The user id that created the integration
CreatedBy string `json:"created_by"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// Groups List of groups
Groups []Group `json:"groups"`
// Id The unique numeric identifier for the integration.
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced.
LastSyncedAt time.Time `json:"last_synced_at"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt time.Time `json:"updated_at"`
// ZtaScoreThreshold The minimum Zero Trust Assessment score required for agent approval (0-100)
ZtaScoreThreshold int `json:"zta_score_threshold"`
}
// EDRHuntressRequest Request payload for creating or updating a EDR Huntress integration
type EDRHuntressRequest struct {
// ApiKey Huntress API key
ApiKey string `json:"api_key"`
// ApiSecret Huntress API secret
ApiSecret string `json:"api_secret"`
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// Groups The Groups this integrations applies to
Groups []string `json:"groups"`
// LastSyncedInterval The devices last sync requirement interval in hours. Minimum value is 24 hours
LastSyncedInterval int `json:"last_synced_interval"`
// MatchAttributes Attribute conditions to match when approving agents
MatchAttributes HuntressMatchAttributes `json:"match_attributes"`
}
// EDRHuntressResponse Represents a Huntress EDR integration configuration
type EDRHuntressResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId string `json:"account_id"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt time.Time `json:"created_at"`
// CreatedBy The user id that created the integration
CreatedBy string `json:"created_by"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// Groups List of groups
Groups []Group `json:"groups"`
// Id The unique numeric identifier for the integration.
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced.
LastSyncedAt time.Time `json:"last_synced_at"`
// LastSyncedInterval The devices last sync requirement interval in hours.
LastSyncedInterval int `json:"last_synced_interval"`
// MatchAttributes Attribute conditions to match when approving agents
MatchAttributes HuntressMatchAttributes `json:"match_attributes"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// EDRIntuneRequest Request payload for creating or updating a EDR Intune integration.
type EDRIntuneRequest struct {
// ClientId The Azure application client id
ClientId string `json:"client_id"`
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// Groups The Groups this integrations applies to
Groups []string `json:"groups"`
// LastSyncedInterval The devices last sync requirement interval in hours. Minimum value is 24 hours.
LastSyncedInterval int `json:"last_synced_interval"`
// Secret The Azure application client secret
Secret string `json:"secret"`
// TenantId The Azure tenant id
TenantId string `json:"tenant_id"`
}
// EDRIntuneResponse Represents a Intune EDR integration configuration
type EDRIntuneResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId string `json:"account_id"`
// ClientId The Azure application client id
ClientId string `json:"client_id"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt time.Time `json:"created_at"`
// CreatedBy The user id that created the integration
CreatedBy string `json:"created_by"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// Groups List of groups
Groups []Group `json:"groups"`
// Id The unique numeric identifier for the integration.
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced.
LastSyncedAt time.Time `json:"last_synced_at"`
// LastSyncedInterval The devices last sync requirement interval in hours.
LastSyncedInterval int `json:"last_synced_interval"`
// TenantId The Azure tenant id
TenantId string `json:"tenant_id"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// EDRSentinelOneRequest Request payload for creating or updating a EDR SentinelOne integration
type EDRSentinelOneRequest struct {
// ApiToken SentinelOne API token
ApiToken string `json:"api_token"`
// ApiUrl The Base URL of SentinelOne API
ApiUrl string `json:"api_url"`
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// Groups The Groups this integrations applies to
Groups []string `json:"groups"`
// LastSyncedInterval The devices last sync requirement interval in hours. Minimum value is 24 hours.
LastSyncedInterval int `json:"last_synced_interval"`
// MatchAttributes Attribute conditions to match when approving agents
MatchAttributes SentinelOneMatchAttributes `json:"match_attributes"`
}
// EDRSentinelOneResponse Represents a SentinelOne EDR integration configuration
type EDRSentinelOneResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId string `json:"account_id"`
// ApiUrl The Base URL of SentinelOne API
ApiUrl string `json:"api_url"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt time.Time `json:"created_at"`
// CreatedBy The user id that created the integration
CreatedBy string `json:"created_by"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// Groups List of groups
Groups []Group `json:"groups"`
// Id The unique numeric identifier for the integration.
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced.
LastSyncedAt time.Time `json:"last_synced_at"`
// LastSyncedInterval The devices last sync requirement interval in hours.
LastSyncedInterval int `json:"last_synced_interval"`
// MatchAttributes Attribute conditions to match when approving agents
MatchAttributes SentinelOneMatchAttributes `json:"match_attributes"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// ErrorResponse Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided.
type ErrorResponse struct {
// Message A human-readable error message.
Message *string `json:"message,omitempty"`
}
// Event defines model for Event.
type Event struct {
// Activity The activity that occurred during the event
@@ -643,6 +977,9 @@ type GeoLocationCheck struct {
// GeoLocationCheckAction Action to take upon policy match
type GeoLocationCheckAction string
// GetTenantsResponse defines model for GetTenantsResponse.
type GetTenantsResponse = []TenantResponse
// Group defines model for Group.
type Group struct {
// Id Group ID
@@ -699,6 +1036,21 @@ type GroupRequest struct {
Resources *[]Resource `json:"resources,omitempty"`
}
// HuntressMatchAttributes Attribute conditions to match when approving agents
type HuntressMatchAttributes struct {
// DefenderPolicyStatus Policy status of Defender AV for Managed Antivirus.
DefenderPolicyStatus *string `json:"defender_policy_status,omitempty"`
// DefenderStatus Status of Defender AV Managed Antivirus.
DefenderStatus *string `json:"defender_status,omitempty"`
// DefenderSubstatus Sub-status of Defender AV Managed Antivirus.
DefenderSubstatus *string `json:"defender_substatus,omitempty"`
// FirewallStatus Status of agent firewall. Can be one of Disabled, Enabled, Pending Isolation, Isolated, Pending Release.
FirewallStatus *string `json:"firewall_status,omitempty"`
}
// IdentityProvider defines model for IdentityProvider.
type IdentityProvider struct {
// ClientId OAuth2 client ID
@@ -738,6 +1090,21 @@ type IdentityProviderRequest struct {
// IdentityProviderType Type of identity provider
type IdentityProviderType string
// IdpIntegrationSyncLog Represents a synchronization log entry for an integration
type IdpIntegrationSyncLog struct {
// Id The unique identifier for the sync log
Id int64 `json:"id"`
// Level The log level
Level string `json:"level"`
// Message Log message
Message string `json:"message"`
// Timestamp Timestamp of when the log was created
Timestamp time.Time `json:"timestamp"`
}
// IngressPeer defines model for IngressPeer.
type IngressPeer struct {
AvailablePorts AvailablePorts `json:"available_ports"`
@@ -892,6 +1259,57 @@ type InstanceVersionInfo struct {
ManagementUpdateAvailable bool `json:"management_update_available"`
}
// IntegrationResponse Represents an event streaming integration.
type IntegrationResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId *string `json:"account_id,omitempty"`
// Config Configuration for the integration. Sensitive keys (like API keys, secret keys) are masked with '****' in responses, as indicated by the GetIntegration handler logic.
Config *map[string]string `json:"config,omitempty"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt *time.Time `json:"created_at,omitempty"`
// Enabled Whether the integration is currently active.
Enabled *bool `json:"enabled,omitempty"`
// Id The unique numeric identifier for the integration.
Id *int64 `json:"id,omitempty"`
// Platform The event streaming platform.
Platform *IntegrationResponsePlatform `json:"platform,omitempty"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt *time.Time `json:"updated_at,omitempty"`
}
// IntegrationResponsePlatform The event streaming platform.
type IntegrationResponsePlatform string
// InvoicePDFResponse defines model for InvoicePDFResponse.
type InvoicePDFResponse struct {
// Url URL to redirect the user to invoice.
Url string `json:"url"`
}
// InvoiceResponse defines model for InvoiceResponse.
type InvoiceResponse struct {
// Id The Stripe invoice id
Id string `json:"id"`
// PeriodEnd The end date of the invoice period.
PeriodEnd time.Time `json:"period_end"`
// PeriodStart The start date of the invoice period.
PeriodStart time.Time `json:"period_start"`
// Type The invoice type
Type InvoiceResponseType `json:"type"`
}
// InvoiceResponseType The invoice type
type InvoiceResponseType string
// JobRequest defines model for JobRequest.
type JobRequest struct {
Workload WorkloadRequest `json:"workload"`
@@ -1797,6 +2215,15 @@ type PolicyUpdate struct {
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
}
// PortalResponse defines model for PortalResponse.
type PortalResponse struct {
// SessionId The unique identifier for the customer portal session.
SessionId string `json:"session_id"`
// Url URL to redirect the user to the customer portal.
Url string `json:"url"`
}
// PostureCheck defines model for PostureCheck.
type PostureCheck struct {
// Checks List of objects that perform the actual checks
@@ -1824,6 +2251,21 @@ type PostureCheckUpdate struct {
Name string `json:"name"`
}
// Price defines model for Price.
type Price struct {
// Currency Currency code for this price.
Currency string `json:"currency"`
// Price Price amount in minor units (e.g., cents).
Price int `json:"price"`
// PriceId Unique identifier for the price.
PriceId string `json:"price_id"`
// Unit Unit of measurement for this price (e.g., per user).
Unit string `json:"unit"`
}
// Process Describes the operational activity within a peer's system.
type Process struct {
// LinuxPath Path to the process executable file in a Linux operating system
@@ -1841,6 +2283,24 @@ type ProcessCheck struct {
Processes []Process `json:"processes"`
}
// Product defines model for Product.
type Product struct {
// Description Detailed description of the product.
Description string `json:"description"`
// Features List of features provided by the product.
Features []string `json:"features"`
// Free Indicates whether the product is free or not.
Free bool `json:"free"`
// Name Name of the product.
Name string `json:"name"`
// Prices List of prices for the product in different currencies
Prices []Price `json:"prices"`
}
// Resource defines model for Resource.
type Resource struct {
// Id ID of the resource
@@ -1950,6 +2410,66 @@ type RulePortRange struct {
Start int `json:"start"`
}
// ScimIntegration Represents a SCIM IDP integration
type ScimIntegration struct {
// AuthToken SCIM API token (full on creation, masked otherwise)
AuthToken string `json:"auth_token"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// GroupPrefixes List of start_with string patterns for groups to sync
GroupPrefixes []string `json:"group_prefixes"`
// Id The unique identifier for the integration
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced
LastSyncedAt time.Time `json:"last_synced_at"`
// Provider Name of the SCIM identity provider
Provider string `json:"provider"`
// UserGroupPrefixes List of start_with string patterns for groups which users to sync
UserGroupPrefixes []string `json:"user_group_prefixes"`
}
// ScimTokenResponse Response containing the regenerated SCIM token
type ScimTokenResponse struct {
// AuthToken The newly generated SCIM API token
AuthToken string `json:"auth_token"`
}
// SentinelOneMatchAttributes Attribute conditions to match when approving agents
type SentinelOneMatchAttributes struct {
// ActiveThreats The maximum allowed number of active threats on the agent
ActiveThreats *int `json:"active_threats,omitempty"`
// EncryptedApplications Whether disk encryption is enabled on the agent
EncryptedApplications *bool `json:"encrypted_applications,omitempty"`
// FirewallEnabled Whether the agent firewall is enabled
FirewallEnabled *bool `json:"firewall_enabled,omitempty"`
// Infected Whether the agent is currently flagged as infected
Infected *bool `json:"infected,omitempty"`
// IsActive Whether the agent has been recently active and reporting
IsActive *bool `json:"is_active,omitempty"`
// IsUpToDate Whether the agent is running the latest available version
IsUpToDate *bool `json:"is_up_to_date,omitempty"`
// NetworkStatus The current network connectivity status of the device
NetworkStatus *SentinelOneMatchAttributesNetworkStatus `json:"network_status,omitempty"`
// OperationalState The current operational state of the agent
OperationalState *string `json:"operational_state,omitempty"`
}
// SentinelOneMatchAttributesNetworkStatus The current network connectivity status of the device
type SentinelOneMatchAttributesNetworkStatus string
// SetupKey defines model for SetupKey.
type SetupKey struct {
// AllowExtraDnsLabels Allow extra DNS labels to be added to the peer
@@ -2121,6 +2641,117 @@ type SetupResponse struct {
UserId string `json:"user_id"`
}
// Subscription defines model for Subscription.
type Subscription struct {
// Active Indicates whether the subscription is active or not.
Active bool `json:"active"`
// Currency Currency code of the subscription.
Currency string `json:"currency"`
// Features List of features included in the subscription.
Features *[]string `json:"features,omitempty"`
// PlanTier The tier of the plan for the subscription.
PlanTier string `json:"plan_tier"`
// Price Price amount in minor units (e.g., cents).
Price int `json:"price"`
// PriceId Unique identifier for the price of the subscription.
PriceId string `json:"price_id"`
// Provider The provider of the subscription.
Provider string `json:"provider"`
// RemainingTrial The remaining time for the trial period, in seconds.
RemainingTrial *int `json:"remaining_trial,omitempty"`
// UpdatedAt The date and time when the subscription was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// TenantGroupResponse defines model for TenantGroupResponse.
type TenantGroupResponse struct {
// Id The Group ID
Id string `json:"id"`
// Role The Role name
Role string `json:"role"`
}
// TenantResponse defines model for TenantResponse.
type TenantResponse struct {
// ActivatedAt The date and time when the tenant was activated.
ActivatedAt *time.Time `json:"activated_at,omitempty"`
// CreatedAt The date and time when the tenant was created.
CreatedAt time.Time `json:"created_at"`
// DnsChallenge The DNS challenge to set in a TXT record
DnsChallenge string `json:"dns_challenge"`
// Domain The tenant account domain
Domain string `json:"domain"`
// Groups MSP users Groups that can access the Tenant and Roles to assume
Groups []TenantGroupResponse `json:"groups"`
// Id The updated MSP tenant account ID
Id string `json:"id"`
// InvitedAt The date and time when the existing tenant was invited.
InvitedAt *time.Time `json:"invited_at,omitempty"`
// Name The name for the MSP tenant
Name string `json:"name"`
// Status The status of the tenant
Status TenantResponseStatus `json:"status"`
// UpdatedAt The date and time when the tenant was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// TenantResponseStatus The status of the tenant
type TenantResponseStatus string
// UpdateScimIntegrationRequest Request payload for updating an SCIM IDP integration
type UpdateScimIntegrationRequest struct {
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// GroupPrefixes List of start_with string patterns for groups to sync
GroupPrefixes *[]string `json:"group_prefixes,omitempty"`
// UserGroupPrefixes List of start_with string patterns for groups which users to sync
UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"`
}
// UpdateTenantRequest defines model for UpdateTenantRequest.
type UpdateTenantRequest struct {
// Groups MSP users Groups that can access the Tenant and Roles to assume
Groups []TenantGroupResponse `json:"groups"`
// Name The name for the MSP tenant
Name string `json:"name"`
}
// UsageStats defines model for UsageStats.
type UsageStats struct {
// ActivePeers Number of active peers.
ActivePeers int64 `json:"active_peers"`
// ActiveUsers Number of active users.
ActiveUsers int64 `json:"active_users"`
// TotalPeers Total number of peers.
TotalPeers int64 `json:"total_peers"`
// TotalUsers Total number of users.
TotalUsers int64 `json:"total_users"`
}
// User defines model for User.
type User struct {
// AutoGroups Group IDs to auto-assign to peers registered by this user
@@ -2407,6 +3038,66 @@ type GetApiGroupsParams struct {
Name *string `form:"name,omitempty" json:"name,omitempty"`
}
// PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody defines parameters for PostApiIntegrationsBillingAwsMarketplaceActivate.
type PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody struct {
// PlanTier The plan tier to activate the subscription for.
PlanTier string `json:"plan_tier"`
}
// PostApiIntegrationsBillingAwsMarketplaceEnrichJSONBody defines parameters for PostApiIntegrationsBillingAwsMarketplaceEnrich.
type PostApiIntegrationsBillingAwsMarketplaceEnrichJSONBody struct {
// AwsUserId The AWS user ID.
AwsUserId string `json:"aws_user_id"`
}
// PostApiIntegrationsBillingCheckoutJSONBody defines parameters for PostApiIntegrationsBillingCheckout.
type PostApiIntegrationsBillingCheckoutJSONBody struct {
// BaseURL The base URL for the redirect after checkout.
BaseURL string `json:"baseURL"`
// EnableTrial Enables a 14-day trial for the account.
EnableTrial *bool `json:"enableTrial,omitempty"`
// PriceID The Price ID for checkout.
PriceID string `json:"priceID"`
}
// GetApiIntegrationsBillingPortalParams defines parameters for GetApiIntegrationsBillingPortal.
type GetApiIntegrationsBillingPortalParams struct {
// BaseURL The base URL for the redirect after accessing the portal.
BaseURL string `form:"baseURL" json:"baseURL"`
}
// PutApiIntegrationsBillingSubscriptionJSONBody defines parameters for PutApiIntegrationsBillingSubscription.
type PutApiIntegrationsBillingSubscriptionJSONBody struct {
// PlanTier The plan tier to change the subscription to.
PlanTier *string `json:"plan_tier,omitempty"`
// PriceID The Price ID to change the subscription to.
PriceID *string `json:"priceID,omitempty"`
}
// PutApiIntegrationsMspTenantsIdInviteJSONBody defines parameters for PutApiIntegrationsMspTenantsIdInvite.
type PutApiIntegrationsMspTenantsIdInviteJSONBody struct {
// Value Accept or decline the invitation.
Value PutApiIntegrationsMspTenantsIdInviteJSONBodyValue `json:"value"`
}
// PutApiIntegrationsMspTenantsIdInviteJSONBodyValue defines parameters for PutApiIntegrationsMspTenantsIdInvite.
type PutApiIntegrationsMspTenantsIdInviteJSONBodyValue string
// PostApiIntegrationsMspTenantsIdSubscriptionJSONBody defines parameters for PostApiIntegrationsMspTenantsIdSubscription.
type PostApiIntegrationsMspTenantsIdSubscriptionJSONBody struct {
// PriceID The Price ID to change the subscription to.
PriceID string `json:"priceID"`
}
// PostApiIntegrationsMspTenantsIdUnlinkJSONBody defines parameters for PostApiIntegrationsMspTenantsIdUnlink.
type PostApiIntegrationsMspTenantsIdUnlinkJSONBody struct {
// Owner The new owners user ID.
Owner string `json:"owner"`
}
// GetApiPeersParams defines parameters for GetApiPeers.
type GetApiPeersParams struct {
// Name Filter peers by name
@@ -2452,6 +3143,12 @@ type PostApiDnsZonesZoneIdRecordsJSONRequestBody = DNSRecordRequest
// PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody defines body for PutApiDnsZonesZoneIdRecordsRecordId for application/json ContentType.
type PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody = DNSRecordRequest
// CreateIntegrationJSONRequestBody defines body for CreateIntegration for application/json ContentType.
type CreateIntegrationJSONRequestBody = CreateIntegrationRequest
// UpdateIntegrationJSONRequestBody defines body for UpdateIntegration for application/json ContentType.
type UpdateIntegrationJSONRequestBody = CreateIntegrationRequest
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
type PostApiGroupsJSONRequestBody = GroupRequest
@@ -2470,6 +3167,63 @@ type PostApiIngressPeersJSONRequestBody = IngressPeerCreateRequest
// PutApiIngressPeersIngressPeerIdJSONRequestBody defines body for PutApiIngressPeersIngressPeerId for application/json ContentType.
type PutApiIngressPeersIngressPeerIdJSONRequestBody = IngressPeerUpdateRequest
// PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody defines body for PostApiIntegrationsBillingAwsMarketplaceActivate for application/json ContentType.
type PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody
// PostApiIntegrationsBillingAwsMarketplaceEnrichJSONRequestBody defines body for PostApiIntegrationsBillingAwsMarketplaceEnrich for application/json ContentType.
type PostApiIntegrationsBillingAwsMarketplaceEnrichJSONRequestBody PostApiIntegrationsBillingAwsMarketplaceEnrichJSONBody
// PostApiIntegrationsBillingCheckoutJSONRequestBody defines body for PostApiIntegrationsBillingCheckout for application/json ContentType.
type PostApiIntegrationsBillingCheckoutJSONRequestBody PostApiIntegrationsBillingCheckoutJSONBody
// PutApiIntegrationsBillingSubscriptionJSONRequestBody defines body for PutApiIntegrationsBillingSubscription for application/json ContentType.
type PutApiIntegrationsBillingSubscriptionJSONRequestBody PutApiIntegrationsBillingSubscriptionJSONBody
// CreateFalconEDRIntegrationJSONRequestBody defines body for CreateFalconEDRIntegration for application/json ContentType.
type CreateFalconEDRIntegrationJSONRequestBody = EDRFalconRequest
// UpdateFalconEDRIntegrationJSONRequestBody defines body for UpdateFalconEDRIntegration for application/json ContentType.
type UpdateFalconEDRIntegrationJSONRequestBody = EDRFalconRequest
// CreateHuntressEDRIntegrationJSONRequestBody defines body for CreateHuntressEDRIntegration for application/json ContentType.
type CreateHuntressEDRIntegrationJSONRequestBody = EDRHuntressRequest
// UpdateHuntressEDRIntegrationJSONRequestBody defines body for UpdateHuntressEDRIntegration for application/json ContentType.
type UpdateHuntressEDRIntegrationJSONRequestBody = EDRHuntressRequest
// CreateEDRIntegrationJSONRequestBody defines body for CreateEDRIntegration for application/json ContentType.
type CreateEDRIntegrationJSONRequestBody = EDRIntuneRequest
// UpdateEDRIntegrationJSONRequestBody defines body for UpdateEDRIntegration for application/json ContentType.
type UpdateEDRIntegrationJSONRequestBody = EDRIntuneRequest
// CreateSentinelOneEDRIntegrationJSONRequestBody defines body for CreateSentinelOneEDRIntegration for application/json ContentType.
type CreateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest
// UpdateSentinelOneEDRIntegrationJSONRequestBody defines body for UpdateSentinelOneEDRIntegration for application/json ContentType.
type UpdateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest
// PostApiIntegrationsMspTenantsJSONRequestBody defines body for PostApiIntegrationsMspTenants for application/json ContentType.
type PostApiIntegrationsMspTenantsJSONRequestBody = CreateTenantRequest
// PutApiIntegrationsMspTenantsIdJSONRequestBody defines body for PutApiIntegrationsMspTenantsId for application/json ContentType.
type PutApiIntegrationsMspTenantsIdJSONRequestBody = UpdateTenantRequest
// PutApiIntegrationsMspTenantsIdInviteJSONRequestBody defines body for PutApiIntegrationsMspTenantsIdInvite for application/json ContentType.
type PutApiIntegrationsMspTenantsIdInviteJSONRequestBody PutApiIntegrationsMspTenantsIdInviteJSONBody
// PostApiIntegrationsMspTenantsIdSubscriptionJSONRequestBody defines body for PostApiIntegrationsMspTenantsIdSubscription for application/json ContentType.
type PostApiIntegrationsMspTenantsIdSubscriptionJSONRequestBody PostApiIntegrationsMspTenantsIdSubscriptionJSONBody
// PostApiIntegrationsMspTenantsIdUnlinkJSONRequestBody defines body for PostApiIntegrationsMspTenantsIdUnlink for application/json ContentType.
type PostApiIntegrationsMspTenantsIdUnlinkJSONRequestBody PostApiIntegrationsMspTenantsIdUnlinkJSONBody
// CreateSCIMIntegrationJSONRequestBody defines body for CreateSCIMIntegration for application/json ContentType.
type CreateSCIMIntegrationJSONRequestBody = CreateScimIntegrationRequest
// UpdateSCIMIntegrationJSONRequestBody defines body for UpdateSCIMIntegration for application/json ContentType.
type UpdateSCIMIntegrationJSONRequestBody = UpdateScimIntegrationRequest
// PostApiNetworksJSONRequestBody defines body for PostApiNetworks for application/json ContentType.
type PostApiNetworksJSONRequestBody = NetworkRequest

View File

@@ -225,35 +225,42 @@ func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, erro
c.mu.Unlock()
return nil, ErrConnAlreadyExists
}
c.mu.Unlock()
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
c.log.Errorf("peer not available: %s, %s", peerID, err)
return nil, err
}
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}
c.log.Infof("prepare the relayed connection, waiting for remote peer: %s", peerID)
c.muInstanceURL.Lock()
instanceURL := c.instanceURL
c.muInstanceURL.Unlock()
conn := NewConn(c, peerID, msgChannel, instanceURL)
_, ok = c.conns[peerID]
if ok {
c.mu.Unlock()
_ = conn.Close()
return nil, ErrConnAlreadyExists
}
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
msgChannel := make(chan Msg, 100)
conn := NewConn(c, peerID, msgChannel, instanceURL)
container := newConnContainer(c.log, conn, msgChannel)
c.conns[peerID] = container
c.mu.Unlock()
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
c.log.Errorf("peer not available: %s, %s", peerID, err)
c.mu.Lock()
if savedContainer, ok := c.conns[peerID]; ok && savedContainer == container {
delete(c.conns, peerID)
}
c.mu.Unlock()
container.close()
return nil, err
}
c.mu.Lock()
if !c.serviceIsRunning {
if savedContainer, ok := c.conns[peerID]; ok && savedContainer == container {
delete(c.conns, peerID)
}
c.mu.Unlock()
container.close()
return nil, fmt.Errorf("relay connection is not established")
}
c.mu.Unlock()
c.log.Infof("remote peer is available: %s", peerID)
return conn, nil
}

View File

@@ -40,7 +40,6 @@ func Execute() error {
func init() {
stopCh = make(chan int)
defaultLogFile = "/var/log/netbird/signal.log"
defaultSignalSSLDir = "/var/lib/netbird/"
if runtime.GOOS == "windows" {
defaultLogFile = os.Getenv("PROGRAMDATA") + "\\Netbird\\" + "signal.log"

View File

@@ -18,7 +18,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -38,13 +38,13 @@ import (
const legacyGRPCPort = 10000
var (
signalPort int
metricsPort int
signalLetsencryptDomain string
signalSSLDir string
defaultSignalSSLDir string
signalCertFile string
signalCertKey string
signalPort int
metricsPort int
signalLetsencryptDomain string
signalLetsencryptEmail string
signalLetsencryptDataDir string
signalCertFile string
signalCertKey string
signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
MinTime: 5 * time.Second,
@@ -216,7 +216,7 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config
}
if signalLetsencryptDomain != "" {
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
certManager, err = encryption.CreateCertManager(signalLetsencryptDataDir, signalLetsencryptDomain)
if err != nil {
return nil, certManager, nil, err
}
@@ -326,9 +326,11 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.PersistentFlags().StringVar(&signalLetsencryptDataDir, "letsencrypt-data-dir", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
runCmd.PersistentFlags().StringVar(&signalLetsencryptDataDir, "ssl-dir", "", "server ssl directory location. *Required only for Let's Encrypt certificates. Deprecated: use --letsencrypt-data-dir")
runCmd.PersistentFlags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.PersistentFlags().StringVar(&signalLetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
runCmd.PersistentFlags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.PersistentFlags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
setFlagsFromEnvVars(runCmd)
}

View File

@@ -24,15 +24,19 @@ type AppMetrics struct {
MessageSize metric.Int64Histogram
}
func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
activePeers, err := meter.Int64UpDownCounter("active_peers",
func NewAppMetrics(meter metric.Meter, prefix ...string) (*AppMetrics, error) {
p := ""
if len(prefix) > 0 {
p = prefix[0]
}
activePeers, err := meter.Int64UpDownCounter(p+"active_peers",
metric.WithDescription("Number of active connected peers"),
)
if err != nil {
return nil, err
}
peerConnectionDuration, err := meter.Int64Histogram("peer_connection_duration_seconds",
peerConnectionDuration, err := meter.Int64Histogram(p+"peer_connection_duration_seconds",
metric.WithExplicitBucketBoundaries(getPeerConnectionDurationBucketBoundaries()...),
metric.WithDescription("Duration of how long a peer was connected"),
)
@@ -40,28 +44,28 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
return nil, err
}
registrations, err := meter.Int64Counter("registrations_total",
registrations, err := meter.Int64Counter(p+"registrations_total",
metric.WithDescription("Total number of peer registrations"),
)
if err != nil {
return nil, err
}
deregistrations, err := meter.Int64Counter("deregistrations_total",
deregistrations, err := meter.Int64Counter(p+"deregistrations_total",
metric.WithDescription("Total number of peer deregistrations"),
)
if err != nil {
return nil, err
}
registrationFailures, err := meter.Int64Counter("registration_failures_total",
registrationFailures, err := meter.Int64Counter(p+"registration_failures_total",
metric.WithDescription("Total number of peer registration failures"),
)
if err != nil {
return nil, err
}
registrationDelay, err := meter.Float64Histogram("registration_delay_milliseconds",
registrationDelay, err := meter.Float64Histogram(p+"registration_delay_milliseconds",
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...),
metric.WithDescription("Duration of how long it takes to register a peer"),
)
@@ -69,7 +73,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
return nil, err
}
getRegistrationDelay, err := meter.Float64Histogram("get_registration_delay_milliseconds",
getRegistrationDelay, err := meter.Float64Histogram(p+"get_registration_delay_milliseconds",
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...),
metric.WithDescription("Duration of how long it takes to load a connection from the registry"),
)
@@ -77,21 +81,21 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
return nil, err
}
messagesForwarded, err := meter.Int64Counter("messages_forwarded_total",
messagesForwarded, err := meter.Int64Counter(p+"messages_forwarded_total",
metric.WithDescription("Total number of messages forwarded to peers"),
)
if err != nil {
return nil, err
}
messageForwardFailures, err := meter.Int64Counter("message_forward_failures_total",
messageForwardFailures, err := meter.Int64Counter(p+"message_forward_failures_total",
metric.WithDescription("Total number of message forwarding failures"),
)
if err != nil {
return nil, err
}
messageForwardLatency, err := meter.Float64Histogram("message_forward_latency_milliseconds",
messageForwardLatency, err := meter.Float64Histogram(p+"message_forward_latency_milliseconds",
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...),
metric.WithDescription("Duration of how long it takes to forward a message to a peer"),
)
@@ -100,7 +104,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
}
messageSize, err := meter.Int64Histogram(
"message.size.bytes",
p+"message.size.bytes",
metric.WithUnit("bytes"),
metric.WithExplicitBucketBoundaries(getMessageSizeBucketBoundaries()...),
metric.WithDescription("Records the size of each message sent"),

View File

@@ -62,8 +62,8 @@ type Server struct {
}
// NewServer creates a new Signal server
func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter)
func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter, metricsPrefix...)
if err != nil {
return nil, fmt.Errorf("creating app metrics: %v", err)
}

View File

@@ -48,7 +48,7 @@ func NewServer(conns []*net.UDPConn, logLevel string) *Server {
// Use the formatter package to set up formatter, ReportCaller, and context hook
formatter.SetTextFormatter(stunLogger)
logger := stunLogger.WithField("component", "stun-server")
logger := stunLogger.WithField("component", "stun")
logger.Infof("STUN server log level set to: %s", level.String())
return &Server{