Compare commits

...

4 Commits

Author SHA1 Message Date
dependabot[bot]
8143834258 Bump github.com/pion/dtls/v3 from 3.0.9 to 3.0.11
Bumps [github.com/pion/dtls/v3](https://github.com/pion/dtls) from 3.0.9 to 3.0.11.
- [Release notes](https://github.com/pion/dtls/releases)
- [Commits](https://github.com/pion/dtls/compare/v3.0.9...v3.0.11)

---
updated-dependencies:
- dependency-name: github.com/pion/dtls/v3
  dependency-version: 3.0.11
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-02-13 14:20:47 +00:00
Bethuel Mmbaga
d3eeb6d8ee [misc] Add cloud api spec to public open api with rest client (#5222) 2026-02-13 15:08:47 +03:00
Bethuel Mmbaga
7ebf37ef20 [management] Enforce access control on accessible peers (#5301) 2026-02-13 12:46:43 +03:00
Misha Bragin
64b849c801 [self-hosted] add netbird server (#5232)
* Unified NetBird combined server (Management, Signal, Relay, STUN) as a single executable with richer YAML configuration, validation, and defaults.
  * Official Dockerfile/image for single-container deployment.
  * Optional in-process profiling endpoint for diagnostics.
  * Multiplexing to route HTTP/gRPC/WebSocket traffic via one port; runtime hooks to inject custom handlers.
* **Chores**
  * Updated deployment scripts, compose files, and reverse-proxy templates to target the combined server; added example configs and getting-started updates.
2026-02-12 19:24:43 +01:00
54 changed files with 9642 additions and 707 deletions

View File

@@ -106,6 +106,26 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-server
dir: combined
env:
- CGO_ENABLED=1
- >-
{{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }}
binary: netbird-server
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
@@ -520,6 +540,55 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
@@ -598,6 +667,18 @@ docker_manifests:
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
@@ -675,6 +756,19 @@ docker_manifests:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
brews:
- ids:
- default

5
combined/Dockerfile Normal file
View File

@@ -0,0 +1,5 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
COPY netbird-server /go/bin/netbird-server

715
combined/cmd/config.go Normal file
View File

@@ -0,0 +1,715 @@
package cmd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"path"
"strings"
"time"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/crypt"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
// CombinedConfig is the root configuration for the combined server.
// The combined server is primarily a Management server with optional embedded
// Signal, Relay, and STUN services.
//
// Architecture:
// - Management: Always runs locally (this IS the management server)
// - Signal: Runs locally by default; disabled if server.signalUri is set
// - Relay: Runs locally by default; disabled if server.relays is set
// - STUN: Runs locally on port 3478 by default; disabled if server.stuns is set
//
// All user-facing settings are under "server". The relay/signal/management
// fields are internal and populated automatically from server settings.
type CombinedConfig struct {
Server ServerConfig `yaml:"server"`
// Internal configs - populated from Server settings, not user-configurable
Relay RelayConfig `yaml:"-"`
Signal SignalConfig `yaml:"-"`
Management ManagementConfig `yaml:"-"`
}
// ServerConfig contains server-wide settings
// In simplified mode, this contains all configuration
type ServerConfig struct {
ListenAddress string `yaml:"listenAddress"`
MetricsPort int `yaml:"metricsPort"`
HealthcheckAddress string `yaml:"healthcheckAddress"`
LogLevel string `yaml:"logLevel"`
LogFile string `yaml:"logFile"`
TLS TLSConfig `yaml:"tls"`
// Simplified config fields (used when relay/signal/management sections are omitted)
ExposedAddress string `yaml:"exposedAddress"` // Public address with protocol (e.g., "https://example.com:443")
StunPorts []int `yaml:"stunPorts"` // STUN ports (empty to disable local STUN)
AuthSecret string `yaml:"authSecret"` // Shared secret for relay authentication
DataDir string `yaml:"dataDir"` // Data directory for all services
// External service overrides (simplified mode)
// When these are set, the corresponding local service is NOT started
// and these values are used for client configuration instead
Stuns []HostConfig `yaml:"stuns"` // External STUN servers (disables local STUN)
Relays RelaysConfig `yaml:"relays"` // External relay servers (disables local relay)
SignalURI string `yaml:"signalUri"` // External signal server (disables local signal)
// Management settings (simplified mode)
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// TLSConfig contains TLS/HTTPS settings
type TLSConfig struct {
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
LetsEncrypt LetsEncryptConfig `yaml:"letsencrypt"`
}
// LetsEncryptConfig contains Let's Encrypt settings
type LetsEncryptConfig struct {
Enabled bool `yaml:"enabled"`
DataDir string `yaml:"dataDir"`
Domains []string `yaml:"domains"`
Email string `yaml:"email"`
AWSRoute53 bool `yaml:"awsRoute53"`
}
// RelayConfig contains relay service settings
type RelayConfig struct {
Enabled bool `yaml:"enabled"`
ExposedAddress string `yaml:"exposedAddress"`
AuthSecret string `yaml:"authSecret"`
LogLevel string `yaml:"logLevel"`
Stun StunConfig `yaml:"stun"`
}
// StunConfig contains embedded STUN service settings
type StunConfig struct {
Enabled bool `yaml:"enabled"`
Ports []int `yaml:"ports"`
LogLevel string `yaml:"logLevel"`
}
// SignalConfig contains signal service settings
type SignalConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
}
// ManagementConfig contains management service settings
type ManagementConfig struct {
Enabled bool `yaml:"enabled"`
LogLevel string `yaml:"logLevel"`
DataDir string `yaml:"dataDir"`
DnsDomain string `yaml:"dnsDomain"`
DisableAnonymousMetrics bool `yaml:"disableAnonymousMetrics"`
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
DisableDefaultPolicy bool `yaml:"disableDefaultPolicy"`
Auth AuthConfig `yaml:"auth"`
Stuns []HostConfig `yaml:"stuns"`
Relays RelaysConfig `yaml:"relays"`
SignalURI string `yaml:"signalUri"`
Store StoreConfig `yaml:"store"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}
// AuthConfig contains authentication/identity provider settings
type AuthConfig struct {
Issuer string `yaml:"issuer"`
LocalAuthDisabled bool `yaml:"localAuthDisabled"`
SignKeyRefreshEnabled bool `yaml:"signKeyRefreshEnabled"`
Storage AuthStorageConfig `yaml:"storage"`
DashboardRedirectURIs []string `yaml:"dashboardRedirectURIs"`
CLIRedirectURIs []string `yaml:"cliRedirectURIs"`
Owner *AuthOwnerConfig `yaml:"owner,omitempty"`
}
// AuthStorageConfig contains auth storage settings
type AuthStorageConfig struct {
Type string `yaml:"type"`
File string `yaml:"file"`
}
// AuthOwnerConfig contains initial admin user settings
type AuthOwnerConfig struct {
Email string `yaml:"email"`
Password string `yaml:"password"`
}
// HostConfig represents a STUN/TURN/Signal host
type HostConfig struct {
URI string `yaml:"uri"`
Proto string `yaml:"proto,omitempty"` // udp, dtls, tcp, http, https - defaults based on URI scheme
Username string `yaml:"username,omitempty"`
Password string `yaml:"password,omitempty"`
}
// RelaysConfig contains external relay server settings for clients
type RelaysConfig struct {
Addresses []string `yaml:"addresses"`
CredentialsTTL string `yaml:"credentialsTTL"`
Secret string `yaml:"secret"`
}
// StoreConfig contains database settings
type StoreConfig struct {
Engine string `yaml:"engine"`
EncryptionKey string `yaml:"encryptionKey"`
DSN string `yaml:"dsn"` // Connection string for postgres or mysql engines
}
// ReverseProxyConfig contains reverse proxy settings
type ReverseProxyConfig struct {
TrustedHTTPProxies []string `yaml:"trustedHTTPProxies"`
TrustedHTTPProxiesCount uint `yaml:"trustedHTTPProxiesCount"`
TrustedPeers []string `yaml:"trustedPeers"`
}
// DefaultConfig returns a CombinedConfig with default values
func DefaultConfig() *CombinedConfig {
return &CombinedConfig{
Server: ServerConfig{
ListenAddress: ":443",
MetricsPort: 9090,
HealthcheckAddress: ":9000",
LogLevel: "info",
LogFile: "console",
StunPorts: []int{3478},
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Store: StoreConfig{
Engine: "sqlite",
},
},
Relay: RelayConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
Stun: StunConfig{
Enabled: false,
Ports: []int{3478},
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
},
Signal: SignalConfig{
// LogLevel inherited from Server.LogLevel via ApplySimplifiedDefaults
},
Management: ManagementConfig{
DataDir: "/var/lib/netbird/",
Auth: AuthConfig{
Storage: AuthStorageConfig{
Type: "sqlite3",
},
},
Relays: RelaysConfig{
CredentialsTTL: "12h",
},
Store: StoreConfig{
Engine: "sqlite",
},
},
}
}
// hasRequiredSettings returns true if the configuration has the required server settings
func (c *CombinedConfig) hasRequiredSettings() bool {
return c.Server.ExposedAddress != ""
}
// parseExposedAddress extracts protocol, host, and host:port from the exposed address
// Input format: "https://example.com:443" or "http://example.com:8080" or "example.com:443"
// Returns: protocol ("https" or "http"), hostname only, and host:port
func parseExposedAddress(exposedAddress string) (protocol, hostname, hostPort string) {
// Default to https if no protocol specified
protocol = "https"
hostPort = exposedAddress
// Check for protocol prefix
if strings.HasPrefix(exposedAddress, "https://") {
protocol = "https"
hostPort = strings.TrimPrefix(exposedAddress, "https://")
} else if strings.HasPrefix(exposedAddress, "http://") {
protocol = "http"
hostPort = strings.TrimPrefix(exposedAddress, "http://")
}
// Extract hostname (without port)
hostname = hostPort
if host, _, err := net.SplitHostPort(hostPort); err == nil {
hostname = host
}
return protocol, hostname, hostPort
}
// ApplySimplifiedDefaults populates internal relay/signal/management configs from server settings.
// Management is always enabled. Signal, Relay, and STUN are enabled unless external
// overrides are configured (server.signalUri, server.relays, server.stuns).
func (c *CombinedConfig) ApplySimplifiedDefaults() {
if !c.hasRequiredSettings() {
return
}
// Parse exposed address to extract protocol and hostname
exposedProto, exposedHost, exposedHostPort := parseExposedAddress(c.Server.ExposedAddress)
// Check for external service overrides
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
hasExternalSignal := c.Server.SignalURI != ""
hasExternalStuns := len(c.Server.Stuns) > 0
// Default stunPorts to [3478] if not specified and no external STUN
if len(c.Server.StunPorts) == 0 && !hasExternalStuns {
c.Server.StunPorts = []int{3478}
}
c.applyRelayDefaults(exposedProto, exposedHostPort, hasExternalRelay, hasExternalStuns)
c.applySignalDefaults(hasExternalSignal)
c.applyManagementDefaults(exposedHost)
// Auto-configure client settings (stuns, relays, signalUri)
c.autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort, hasExternalStuns, hasExternalRelay, hasExternalSignal)
}
// applyRelayDefaults configures the relay service if no external relay is configured.
func (c *CombinedConfig) applyRelayDefaults(exposedProto, exposedHostPort string, hasExternalRelay, hasExternalStuns bool) {
if hasExternalRelay {
return
}
c.Relay.Enabled = true
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
c.Relay.ExposedAddress = fmt.Sprintf("%s://%s", relayProto, exposedHostPort)
c.Relay.AuthSecret = c.Server.AuthSecret
if c.Relay.LogLevel == "" {
c.Relay.LogLevel = c.Server.LogLevel
}
// Enable local STUN only if no external STUN servers and stunPorts are configured
if !hasExternalStuns && len(c.Server.StunPorts) > 0 {
c.Relay.Stun.Enabled = true
c.Relay.Stun.Ports = c.Server.StunPorts
if c.Relay.Stun.LogLevel == "" {
c.Relay.Stun.LogLevel = c.Server.LogLevel
}
}
}
// applySignalDefaults configures the signal service if no external signal is configured.
func (c *CombinedConfig) applySignalDefaults(hasExternalSignal bool) {
if hasExternalSignal {
return
}
c.Signal.Enabled = true
if c.Signal.LogLevel == "" {
c.Signal.LogLevel = c.Server.LogLevel
}
}
// applyManagementDefaults configures the management service (always enabled).
func (c *CombinedConfig) applyManagementDefaults(exposedHost string) {
c.Management.Enabled = true
if c.Management.LogLevel == "" {
c.Management.LogLevel = c.Server.LogLevel
}
if c.Management.DataDir == "" || c.Management.DataDir == "/var/lib/netbird/" {
c.Management.DataDir = c.Server.DataDir
}
c.Management.DnsDomain = exposedHost
c.Management.DisableAnonymousMetrics = c.Server.DisableAnonymousMetrics
c.Management.DisableGeoliteUpdate = c.Server.DisableGeoliteUpdate
// Copy auth config from server if management auth issuer is not set
if c.Management.Auth.Issuer == "" && c.Server.Auth.Issuer != "" {
c.Management.Auth = c.Server.Auth
}
// Copy store config from server if not set
if c.Management.Store.Engine == "" || c.Management.Store.Engine == "sqlite" {
if c.Server.Store.Engine != "" {
c.Management.Store = c.Server.Store
}
}
// Copy reverse proxy config from server
if len(c.Server.ReverseProxy.TrustedHTTPProxies) > 0 || c.Server.ReverseProxy.TrustedHTTPProxiesCount > 0 || len(c.Server.ReverseProxy.TrustedPeers) > 0 {
c.Management.ReverseProxy = c.Server.ReverseProxy
}
}
// autoConfigureClientSettings sets up STUN/relay/signal URIs for clients
// External overrides from server config take precedence over auto-generated values
func (c *CombinedConfig) autoConfigureClientSettings(exposedProto, exposedHost, exposedHostPort string, hasExternalStuns, hasExternalRelay, hasExternalSignal bool) {
// Determine relay protocol from exposed protocol
relayProto := "rel"
if exposedProto == "https" {
relayProto = "rels"
}
// Configure STUN servers for clients
if hasExternalStuns {
// Use external STUN servers from server config
c.Management.Stuns = c.Server.Stuns
} else if len(c.Server.StunPorts) > 0 && len(c.Management.Stuns) == 0 {
// Auto-configure local STUN servers for all ports
for _, port := range c.Server.StunPorts {
c.Management.Stuns = append(c.Management.Stuns, HostConfig{
URI: fmt.Sprintf("stun:%s:%d", exposedHost, port),
})
}
}
// Configure relay for clients
if hasExternalRelay {
// Use external relay config from server
c.Management.Relays = c.Server.Relays
} else if len(c.Management.Relays.Addresses) == 0 {
// Auto-configure local relay
c.Management.Relays.Addresses = []string{
fmt.Sprintf("%s://%s", relayProto, exposedHostPort),
}
}
if c.Management.Relays.Secret == "" {
c.Management.Relays.Secret = c.Server.AuthSecret
}
if c.Management.Relays.CredentialsTTL == "" {
c.Management.Relays.CredentialsTTL = "12h"
}
// Configure signal for clients
if hasExternalSignal {
// Use external signal URI from server config
c.Management.SignalURI = c.Server.SignalURI
} else if c.Management.SignalURI == "" {
// Auto-configure local signal
c.Management.SignalURI = fmt.Sprintf("%s://%s", exposedProto, exposedHostPort)
}
}
// LoadConfig loads configuration from a YAML file
func LoadConfig(configPath string) (*CombinedConfig, error) {
cfg := DefaultConfig()
if configPath == "" {
return cfg, nil
}
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}
if err := yaml.Unmarshal(data, cfg); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Populate internal configs from server settings
cfg.ApplySimplifiedDefaults()
return cfg, nil
}
// Validate validates the configuration
func (c *CombinedConfig) Validate() error {
if c.Server.ExposedAddress == "" {
return fmt.Errorf("server.exposedAddress is required")
}
if c.Server.DataDir == "" {
return fmt.Errorf("server.dataDir is required")
}
// Validate STUN ports
seen := make(map[int]bool)
for _, port := range c.Server.StunPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid server.stunPorts value %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d in server.stunPorts", port)
}
seen[port] = true
}
// authSecret is required only if running local relay (no external relay configured)
hasExternalRelay := len(c.Server.Relays.Addresses) > 0
if !hasExternalRelay && c.Server.AuthSecret == "" {
return fmt.Errorf("server.authSecret is required when running local relay")
}
return nil
}
// HasTLSCert returns true if TLS certificate files are configured
func (c *CombinedConfig) HasTLSCert() bool {
return c.Server.TLS.CertFile != "" && c.Server.TLS.KeyFile != ""
}
// HasLetsEncrypt returns true if Let's Encrypt is configured
func (c *CombinedConfig) HasLetsEncrypt() bool {
return c.Server.TLS.LetsEncrypt.Enabled &&
c.Server.TLS.LetsEncrypt.DataDir != "" &&
len(c.Server.TLS.LetsEncrypt.Domains) > 0
}
// parseExplicitProtocol parses an explicit protocol string to nbconfig.Protocol
func parseExplicitProtocol(proto string) (nbconfig.Protocol, bool) {
switch strings.ToLower(proto) {
case "udp":
return nbconfig.UDP, true
case "dtls":
return nbconfig.DTLS, true
case "tcp":
return nbconfig.TCP, true
case "http":
return nbconfig.HTTP, true
case "https":
return nbconfig.HTTPS, true
default:
return "", false
}
}
// parseStunProtocol determines protocol for STUN/TURN servers.
// stun: → UDP, stuns: → DTLS, turn: → UDP, turns: → DTLS
// Explicit proto overrides URI scheme. Defaults to UDP.
func parseStunProtocol(uri, proto string) nbconfig.Protocol {
if proto != "" {
if p, ok := parseExplicitProtocol(proto); ok {
return p
}
}
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "stuns:"):
return nbconfig.DTLS
case strings.HasPrefix(uri, "turns:"):
return nbconfig.DTLS
default:
// stun:, turn:, or no scheme - default to UDP
return nbconfig.UDP
}
}
// parseSignalProtocol determines protocol for Signal servers.
// https:// → HTTPS, http:// → HTTP. Defaults to HTTPS.
func parseSignalProtocol(uri string) nbconfig.Protocol {
uri = strings.ToLower(uri)
switch {
case strings.HasPrefix(uri, "http://"):
return nbconfig.HTTP
default:
// https:// or no scheme - default to HTTPS
return nbconfig.HTTPS
}
}
// stripSignalProtocol removes the protocol prefix from a signal URI.
// Returns just the host:port (e.g., "selfhosted2.demo.netbird.io:443").
func stripSignalProtocol(uri string) string {
uri = strings.TrimPrefix(uri, "https://")
uri = strings.TrimPrefix(uri, "http://")
return uri
}
// ToManagementConfig converts CombinedConfig to management server config
func (c *CombinedConfig) ToManagementConfig() (*nbconfig.Config, error) {
mgmt := c.Management
// Build STUN hosts
var stuns []*nbconfig.Host
for _, s := range mgmt.Stuns {
stuns = append(stuns, &nbconfig.Host{
URI: s.URI,
Proto: parseStunProtocol(s.URI, s.Proto),
Username: s.Username,
Password: s.Password,
})
}
// Build relay config
var relayConfig *nbconfig.Relay
if len(mgmt.Relays.Addresses) > 0 || mgmt.Relays.Secret != "" {
var ttl time.Duration
if mgmt.Relays.CredentialsTTL != "" {
var err error
ttl, err = time.ParseDuration(mgmt.Relays.CredentialsTTL)
if err != nil {
return nil, fmt.Errorf("invalid relay credentials TTL %q: %w", mgmt.Relays.CredentialsTTL, err)
}
}
relayConfig = &nbconfig.Relay{
Addresses: mgmt.Relays.Addresses,
CredentialsTTL: util.Duration{Duration: ttl},
Secret: mgmt.Relays.Secret,
}
}
// Build signal config
var signalConfig *nbconfig.Host
if mgmt.SignalURI != "" {
signalConfig = &nbconfig.Host{
URI: stripSignalProtocol(mgmt.SignalURI),
Proto: parseSignalProtocol(mgmt.SignalURI),
}
}
// Build store config
storeConfig := nbconfig.StoreConfig{
Engine: types.Engine(mgmt.Store.Engine),
}
// Build reverse proxy config
reverseProxy := nbconfig.ReverseProxy{
TrustedHTTPProxiesCount: mgmt.ReverseProxy.TrustedHTTPProxiesCount,
}
for _, p := range mgmt.ReverseProxy.TrustedHTTPProxies {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedHTTPProxies = append(reverseProxy.TrustedHTTPProxies, prefix)
}
}
for _, p := range mgmt.ReverseProxy.TrustedPeers {
if prefix, err := netip.ParsePrefix(p); err == nil {
reverseProxy.TrustedPeers = append(reverseProxy.TrustedPeers, prefix)
}
}
// Build HTTP config (required, even if empty)
httpConfig := &nbconfig.HttpServerConfig{}
// Build embedded IDP config (always enabled in combined server)
storageFile := mgmt.Auth.Storage.File
if storageFile == "" {
storageFile = path.Join(mgmt.DataDir, "idp.db")
}
embeddedIdP := &idp.EmbeddedIdPConfig{
Enabled: true,
Issuer: mgmt.Auth.Issuer,
LocalAuthDisabled: mgmt.Auth.LocalAuthDisabled,
SignKeyRefreshEnabled: mgmt.Auth.SignKeyRefreshEnabled,
Storage: idp.EmbeddedStorageConfig{
Type: mgmt.Auth.Storage.Type,
Config: idp.EmbeddedStorageTypeConfig{
File: storageFile,
},
},
DashboardRedirectURIs: mgmt.Auth.DashboardRedirectURIs,
CLIRedirectURIs: mgmt.Auth.CLIRedirectURIs,
}
if mgmt.Auth.Owner != nil && mgmt.Auth.Owner.Email != "" {
embeddedIdP.Owner = &idp.OwnerConfig{
Email: mgmt.Auth.Owner.Email,
Hash: mgmt.Auth.Owner.Password, // Will be hashed if plain text
}
}
// Set HTTP config fields for embedded IDP
httpConfig.AuthIssuer = mgmt.Auth.Issuer
httpConfig.IdpSignKeyRefreshEnabled = mgmt.Auth.SignKeyRefreshEnabled
return &nbconfig.Config{
Stuns: stuns,
Relay: relayConfig,
Signal: signalConfig,
Datadir: mgmt.DataDir,
DataStoreEncryptionKey: mgmt.Store.EncryptionKey,
HttpConfig: httpConfig,
StoreConfig: storeConfig,
ReverseProxy: reverseProxy,
DisableDefaultPolicy: mgmt.DisableDefaultPolicy,
EmbeddedIdP: embeddedIdP,
}, nil
}
// ApplyEmbeddedIdPConfig applies embedded IdP configuration to the management config.
// This mirrors the logic in management/cmd/management.go ApplyEmbeddedIdPConfig.
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config, mgmtPort int, disableSingleAccMode bool) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
// Embedded IdP requires single account mode
if disableSingleAccMode {
return fmt.Errorf("embedded IdP requires single account mode; multiple account mode is not supported with embedded IdP")
}
// Set LocalAddress for embedded IdP, used for internal JWT validation
cfg.EmbeddedIdP.LocalAddress = fmt.Sprintf("localhost:%d", mgmtPort)
// Set storage defaults based on Datadir
if cfg.EmbeddedIdP.Storage.Type == "" {
cfg.EmbeddedIdP.Storage.Type = "sqlite3"
}
if cfg.EmbeddedIdP.Storage.Config.File == "" && cfg.Datadir != "" {
cfg.EmbeddedIdP.Storage.Config.File = path.Join(cfg.Datadir, "idp.db")
}
issuer := cfg.EmbeddedIdP.Issuer
// Ensure HttpConfig exists
if cfg.HttpConfig == nil {
cfg.HttpConfig = &nbconfig.HttpServerConfig{}
}
// Set HttpConfig values from EmbeddedIdP
cfg.HttpConfig.AuthIssuer = issuer
cfg.HttpConfig.AuthAudience = "netbird-dashboard"
cfg.HttpConfig.CLIAuthAudience = "netbird-cli"
cfg.HttpConfig.AuthUserIDClaim = "sub"
cfg.HttpConfig.AuthKeysLocation = issuer + "/keys"
cfg.HttpConfig.OIDCConfigEndpoint = issuer + "/.well-known/openid-configuration"
cfg.HttpConfig.IdpSignKeyRefreshEnabled = true
return nil
}
// EnsureEncryptionKey generates an encryption key if not set.
// Unlike management server, we don't write back to the config file.
func EnsureEncryptionKey(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}
log.WithContext(ctx).Infof("DataStoreEncryptionKey is not set, generating a new key")
key, err := crypt.GenerateKey()
if err != nil {
return fmt.Errorf("failed to generate datastore encryption key: %v", err)
}
cfg.DataStoreEncryptionKey = key
keyPreview := key[:8] + "..."
log.WithContext(ctx).Warnf("DataStoreEncryptionKey generated (%s); add it to your config file under 'server.store.encryptionKey' to persist across restarts", keyPreview)
return nil
}
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil && cfg.EmbeddedIdP.Enabled {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
if cfg.Relay != nil {
log.Infof("Relay addresses: %v", cfg.Relay.Addresses)
}
}

33
combined/cmd/pprof.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

711
combined/cmd/root.go Normal file
View File

@@ -0,0 +1,711 @@
package cmd
import (
"context"
"crypto/sha256"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/coder/websocket"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"go.opentelemetry.io/otel/metric"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/encryption"
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/relay/healthcheck"
relayServer "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/relay/server/listener/ws"
sharedMetrics "github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/wsproxy"
wsproxyserver "github.com/netbirdio/netbird/util/wsproxy/server"
)
var (
configPath string
config *CombinedConfig
rootCmd = &cobra.Command{
Use: "combined",
Short: "Combined Netbird server (Management + Signal + Relay + STUN)",
Long: `Combined Netbird server for self-hosted deployments.
All services (Management, Signal, Relay) are multiplexed on a single port.
Optional STUN server runs on separate UDP ports.
Configuration is loaded from a YAML file specified with --config.`,
SilenceUsage: true,
SilenceErrors: true,
RunE: execute,
}
)
func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
}
func Execute() error {
return rootCmd.Execute()
}
func waitForExitSignal() {
osSigs := make(chan os.Signal, 1)
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
<-osSigs
}
func execute(cmd *cobra.Command, _ []string) error {
if err := initializeConfig(); err != nil {
return err
}
// Management is required as the base server when signal or relay are enabled
if (config.Signal.Enabled || config.Relay.Enabled) && !config.Management.Enabled {
return fmt.Errorf("management must be enabled when signal or relay are enabled (provides the base HTTP server)")
}
servers, err := createAllServers(cmd.Context(), config)
if err != nil {
return err
}
// Register services with management's gRPC server using AfterInit hook
setupServerHooks(servers, config)
// Start management server (this also starts the HTTP listener)
if servers.mgmtSrv != nil {
if err := servers.mgmtSrv.Start(cmd.Context()); err != nil {
cleanupSTUNListeners(servers.stunListeners)
return fmt.Errorf("failed to start management server: %w", err)
}
}
// Start all other servers
wg := sync.WaitGroup{}
startServers(&wg, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.metricsServer)
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = shutdownServers(ctx, servers.relaySrv, servers.healthcheck, servers.stunServer, servers.mgmtSrv, servers.metricsServer)
wg.Wait()
return err
}
// initializeConfig loads and validates the configuration, then initializes logging.
func initializeConfig() error {
var err error
config, err = LoadConfig(configPath)
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
if err := config.Validate(); err != nil {
return fmt.Errorf("invalid config: %w", err)
}
if err := util.InitLog(config.Server.LogLevel, config.Server.LogFile); err != nil {
return fmt.Errorf("failed to initialize log: %w", err)
}
if dsn := config.Server.Store.DSN; dsn != "" {
switch strings.ToLower(config.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
log.Infof("Starting combined NetBird server")
logConfig(config)
logEnvVars()
return nil
}
// serverInstances holds all server instances created during startup.
type serverInstances struct {
relaySrv *relayServer.Server
mgmtSrv *mgmtServer.BaseServer
signalSrv *signalServer.Server
healthcheck *healthcheck.Server
stunServer *stun.Server
stunListeners []*net.UDPConn
metricsServer *sharedMetrics.Metrics
}
// createAllServers creates all server instances based on configuration.
func createAllServers(ctx context.Context, cfg *CombinedConfig) (*serverInstances, error) {
metricsServer, err := sharedMetrics.NewServer(cfg.Server.MetricsPort, "")
if err != nil {
return nil, fmt.Errorf("failed to create metrics server: %w", err)
}
servers := &serverInstances{
metricsServer: metricsServer,
}
_, tlsSupport, err := handleTLSConfig(cfg)
if err != nil {
return nil, fmt.Errorf("failed to setup TLS config: %w", err)
}
if err := servers.createRelayServer(cfg, tlsSupport); err != nil {
return nil, err
}
if err := servers.createManagementServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createSignalServer(ctx, cfg); err != nil {
return nil, err
}
if err := servers.createHealthcheckServer(cfg); err != nil {
return nil, err
}
return servers, nil
}
func (s *serverInstances) createRelayServer(cfg *CombinedConfig, tlsSupport bool) error {
if !cfg.Relay.Enabled {
return nil
}
var err error
s.stunListeners, err = createSTUNListeners(cfg)
if err != nil {
return err
}
hashedSecret := sha256.Sum256([]byte(cfg.Relay.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
relayCfg := relayServer.Config{
Meter: s.metricsServer.Meter,
ExposedAddress: cfg.Relay.ExposedAddress,
AuthValidator: authenticator,
TLSSupport: tlsSupport,
}
s.relaySrv, err = createRelayServer(relayCfg, s.stunListeners)
if err != nil {
return err
}
log.Infof("Relay server created")
if len(s.stunListeners) > 0 {
s.stunServer = stun.NewServer(s.stunListeners, cfg.Relay.Stun.LogLevel)
}
return nil
}
func (s *serverInstances) createManagementServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Management.Enabled {
return nil
}
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return fmt.Errorf("failed to create management config: %w", err)
}
_, portStr, portErr := net.SplitHostPort(cfg.Server.ListenAddress)
if portErr != nil {
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
if err := ApplyEmbeddedIdPConfig(ctx, mgmtConfig, mgmtPort, false); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to apply embedded IdP config: %w", err)
}
if err := EnsureEncryptionKey(ctx, mgmtConfig); err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to ensure encryption key: %w", err)
}
LogConfigInfo(mgmtConfig)
s.mgmtSrv, err = createManagementServer(cfg, mgmtConfig)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management server: %w", err)
}
// Inject externally-managed AppMetrics so management uses the shared metrics server
appMetrics, err := telemetry.NewAppMetricsWithMeter(ctx, s.metricsServer.Meter)
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create management app metrics: %w", err)
}
mgmtServer.Inject[telemetry.AppMetrics](s.mgmtSrv, appMetrics)
log.Infof("Management server created")
return nil
}
func (s *serverInstances) createSignalServer(ctx context.Context, cfg *CombinedConfig) error {
if !cfg.Signal.Enabled {
return nil
}
var err error
s.signalSrv, err = signalServer.NewServer(ctx, s.metricsServer.Meter, "signal_")
if err != nil {
cleanupSTUNListeners(s.stunListeners)
return fmt.Errorf("failed to create signal server: %w", err)
}
log.Infof("Signal server created")
return nil
}
func (s *serverInstances) createHealthcheckServer(cfg *CombinedConfig) error {
hCfg := healthcheck.Config{
ListenAddress: cfg.Server.HealthcheckAddress,
ServiceChecker: s.relaySrv,
}
var err error
s.healthcheck, err = createHealthCheck(hCfg, s.stunListeners)
return err
}
// setupServerHooks registers services with management's gRPC server.
func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
if servers.mgmtSrv == nil {
return
}
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
})
}
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
if srv != nil {
instanceURL := srv.InstanceURL()
log.Infof("Relay server instance URL: %s", instanceURL.String())
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
var errs error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}
if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
}
if srv != nil {
if err := srv.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
}
}
if mgmtSrv != nil {
log.Infof("shutting down management and signal servers")
if err := mgmtSrv.Stop(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close management server: %w", err))
}
}
if metricsServer != nil {
log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
}
}
return errs
}
func createHealthCheck(hCfg healthcheck.Config, stunListeners []*net.UDPConn) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create healthcheck server: %w", err)
}
return httpHealthcheck, nil
}
func createRelayServer(cfg relayServer.Config, stunListeners []*net.UDPConn) (*relayServer.Server, error) {
srv, err := relayServer.NewServer(cfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create relay server: %w", err)
}
return srv, nil
}
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}
func createSTUNListeners(cfg *CombinedConfig) ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cfg.Relay.Stun.Enabled {
for _, port := range cfg.Relay.Stun.Ports {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
cleanupSTUNListeners(stunListeners)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %w", port, err)
}
stunListeners = append(stunListeners, listener)
log.Infof("STUN server listening on UDP port %d", port)
}
}
return stunListeners, nil
}
func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
tlsCfg := cfg.Server.TLS
if tlsCfg.LetsEncrypt.AWSRoute53 {
log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
r53 := encryption.Route53TLS{
DataDir: tlsCfg.LetsEncrypt.DataDir,
Email: tlsCfg.LetsEncrypt.Email,
Domains: tlsCfg.LetsEncrypt.Domains,
}
tc, err := r53.GetCertificate()
if err != nil {
return nil, false, err
}
return tc, true, nil
}
if cfg.HasLetsEncrypt() {
log.Infof("setting up TLS with Let's Encrypt")
certManager, err := encryption.CreateCertManager(tlsCfg.LetsEncrypt.DataDir, tlsCfg.LetsEncrypt.Domains...)
if err != nil {
return nil, false, fmt.Errorf("failed creating LetsEncrypt cert manager: %w", err)
}
return certManager.TLSConfig(), true, nil
}
if cfg.HasTLSCert() {
log.Debugf("using file based TLS config")
tc, err := encryption.LoadTLSConfig(tlsCfg.CertFile, tlsCfg.KeyFile)
if err != nil {
return nil, false, err
}
return tc, true, nil
}
return nil, false, nil
}
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
mgmt := cfg.Management
dnsDomain := mgmt.DnsDomain
singleAccModeDomain := dnsDomain
// Extract port from listen address
_, portStr, err := net.SplitHostPort(cfg.Server.ListenAddress)
if err != nil {
// If no port specified, assume default
portStr = "443"
}
mgmtPort, _ := strconv.Atoi(portStr)
mgmtSrv := mgmtServer.NewServer(
mgmtConfig,
dnsDomain,
singleAccModeDomain,
mgmtPort,
cfg.Server.MetricsPort,
mgmt.DisableAnonymousMetrics,
mgmt.DisableGeoliteUpdate,
// Always enable user deletion from IDP in combined server (embedded IdP is always enabled)
true,
)
return mgmtSrv, nil
}
// createCombinedHandler creates an HTTP handler that multiplexes Management, Signal (via wsproxy), and Relay WebSocket traffic
func createCombinedHandler(grpcServer *grpc.Server, httpHandler http.Handler, relaySrv *relayServer.Server, meter metric.Meter, cfg *CombinedConfig) http.Handler {
wsProxy := wsproxyserver.New(grpcServer, wsproxyserver.WithOTelMeter(meter))
var relayAcceptFn func(conn net.Conn)
if relaySrv != nil {
relayAcceptFn = relaySrv.RelayAccept()
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
// Native gRPC traffic (HTTP/2 with gRPC content-type)
case r.ProtoMajor == 2 && (strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc") ||
strings.HasPrefix(r.Header.Get("Content-Type"), "application/grpc+proto")):
grpcServer.ServeHTTP(w, r)
// WebSocket proxy for Management gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.ManagementComponent:
wsProxy.Handler().ServeHTTP(w, r)
// WebSocket proxy for Signal gRPC
case r.URL.Path == wsproxy.ProxyPath+wsproxy.SignalComponent:
if cfg.Signal.Enabled {
wsProxy.Handler().ServeHTTP(w, r)
} else {
http.Error(w, "Signal service not enabled", http.StatusNotFound)
}
// Relay WebSocket
case r.URL.Path == "/relay":
if relayAcceptFn != nil {
handleRelayWebSocket(w, r, relayAcceptFn, cfg)
} else {
http.Error(w, "Relay service not enabled", http.StatusNotFound)
}
// Management HTTP API (default)
default:
httpHandler.ServeHTTP(w, r)
}
})
}
// handleRelayWebSocket handles incoming WebSocket connections for the relay service
func handleRelayWebSocket(w http.ResponseWriter, r *http.Request, acceptFn func(conn net.Conn), cfg *CombinedConfig) {
acceptOptions := &websocket.AcceptOptions{
OriginPatterns: []string{"*"},
}
wsConn, err := websocket.Accept(w, r, acceptOptions)
if err != nil {
log.Errorf("failed to accept relay ws connection: %s", err)
return
}
connRemoteAddr := r.RemoteAddr
if r.Header.Get("X-Real-Ip") != "" && r.Header.Get("X-Real-Port") != "" {
connRemoteAddr = net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port"))
}
rAddr, err := net.ResolveTCPAddr("tcp", connRemoteAddr)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
lAddr, err := net.ResolveTCPAddr("tcp", cfg.Server.ListenAddress)
if err != nil {
_ = wsConn.Close(websocket.StatusInternalError, "internal error")
return
}
log.Debugf("Relay WS client connected from: %s", rAddr)
conn := ws.NewConn(wsConn, lAddr, rAddr)
acceptFn(conn)
}
// logConfig prints all configuration parameters for debugging
func logConfig(cfg *CombinedConfig) {
log.Info("=== Configuration ===")
logServerConfig(cfg)
logComponentsConfig(cfg)
logRelayConfig(cfg)
logManagementConfig(cfg)
log.Info("=== End Configuration ===")
}
func logServerConfig(cfg *CombinedConfig) {
log.Info("--- Server ---")
log.Infof(" Listen address: %s", cfg.Server.ListenAddress)
log.Infof(" Exposed address: %s", cfg.Server.ExposedAddress)
log.Infof(" Healthcheck address: %s", cfg.Server.HealthcheckAddress)
log.Infof(" Metrics port: %d", cfg.Server.MetricsPort)
log.Infof(" Log level: %s", cfg.Server.LogLevel)
log.Infof(" Data dir: %s", cfg.Server.DataDir)
switch {
case cfg.HasTLSCert():
log.Infof(" TLS: cert=%s, key=%s", cfg.Server.TLS.CertFile, cfg.Server.TLS.KeyFile)
case cfg.HasLetsEncrypt():
log.Infof(" TLS: Let's Encrypt (domains=%v)", cfg.Server.TLS.LetsEncrypt.Domains)
default:
log.Info(" TLS: disabled (using reverse proxy)")
}
}
func logComponentsConfig(cfg *CombinedConfig) {
log.Info("--- Components ---")
log.Infof(" Management: %v (log level: %s)", cfg.Management.Enabled, cfg.Management.LogLevel)
log.Infof(" Signal: %v (log level: %s)", cfg.Signal.Enabled, cfg.Signal.LogLevel)
log.Infof(" Relay: %v (log level: %s)", cfg.Relay.Enabled, cfg.Relay.LogLevel)
}
func logRelayConfig(cfg *CombinedConfig) {
if !cfg.Relay.Enabled {
return
}
log.Info("--- Relay ---")
log.Infof(" Exposed address: %s", cfg.Relay.ExposedAddress)
log.Infof(" Auth secret: %s...", maskSecret(cfg.Relay.AuthSecret))
if cfg.Relay.Stun.Enabled {
log.Infof(" STUN ports: %v (log level: %s)", cfg.Relay.Stun.Ports, cfg.Relay.Stun.LogLevel)
} else {
log.Info(" STUN: disabled")
}
}
func logManagementConfig(cfg *CombinedConfig) {
if !cfg.Management.Enabled {
return
}
log.Info("--- Management ---")
log.Infof(" Data dir: %s", cfg.Management.DataDir)
log.Infof(" DNS domain: %s", cfg.Management.DnsDomain)
log.Infof(" Store engine: %s", cfg.Management.Store.Engine)
if cfg.Server.Store.DSN != "" {
log.Infof(" Store DSN: %s", maskDSNPassword(cfg.Server.Store.DSN))
}
log.Info(" Auth (embedded IdP):")
log.Infof(" Issuer: %s", cfg.Management.Auth.Issuer)
log.Infof(" Dashboard redirect URIs: %v", cfg.Management.Auth.DashboardRedirectURIs)
log.Infof(" CLI redirect URIs: %v", cfg.Management.Auth.CLIRedirectURIs)
log.Info(" Client settings:")
log.Infof(" Signal URI: %s", cfg.Management.SignalURI)
for _, s := range cfg.Management.Stuns {
log.Infof(" STUN: %s", s.URI)
}
if len(cfg.Management.Relays.Addresses) > 0 {
log.Infof(" Relay addresses: %v", cfg.Management.Relays.Addresses)
log.Infof(" Relay credentials TTL: %s", cfg.Management.Relays.CredentialsTTL)
}
}
// logEnvVars logs all NB_ environment variables that are currently set
func logEnvVars() {
log.Info("=== Environment Variables ===")
found := false
for _, env := range os.Environ() {
if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key)
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
value = maskSecret(value)
}
log.Infof(" %s=%s", key, value)
found = true
}
}
if !found {
log.Info(" (none set)")
}
log.Info("=== End Environment Variables ===")
}
// maskDSNPassword masks the password in a DSN string.
// Handles both key=value format ("password=secret") and URI format ("user:secret@host").
func maskDSNPassword(dsn string) string {
// Key=value format: "host=localhost user=nb password=secret dbname=nb"
if strings.Contains(dsn, "password=") {
parts := strings.Fields(dsn)
for i, p := range parts {
if strings.HasPrefix(p, "password=") {
parts[i] = "password=****"
}
}
return strings.Join(parts, " ")
}
// URI format: "user:password@host..."
if atIdx := strings.Index(dsn, "@"); atIdx != -1 {
prefix := dsn[:atIdx]
if colonIdx := strings.Index(prefix, ":"); colonIdx != -1 {
return prefix[:colonIdx+1] + "****" + dsn[atIdx:]
}
}
return dsn
}
// maskSecret returns first 4 chars of secret followed by "..."
func maskSecret(secret string) string {
if len(secret) <= 4 {
return "****"
}
return secret[:4] + "..."
}

View File

@@ -0,0 +1,111 @@
# NetBird Combined Server Configuration
# Copy this file to config.yaml and customize for your deployment
#
# This is a Management server with optional embedded Signal, Relay, and STUN services.
# By default, all services run locally. You can use external services instead by
# setting the corresponding override fields.
#
# Architecture:
# - Management: Always runs locally (this IS the management server)
# - Signal: Local by default; set 'signalUri' to use external (disables local)
# - Relay: Local by default; set 'relays' to use external (disables local)
# - STUN: Local on port 3478 by default; set 'stuns' to use external instead
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Public address that peers will use to connect to this server
# Used for relay connections and management DNS domain
# Format: protocol://hostname:port (e.g., https://server.mycompany.com:443)
exposedAddress: "https://server.mycompany.com:443"
# STUN server ports (defaults to [3478] if not specified; set 'stuns' to use external)
# stunPorts:
# - 3478
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # Default log level for all components: panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Shared secret for relay authentication (required when running local relay)
authSecret: "your-secret-key-here"
# Data directory for all services
dataDir: "/var/lib/netbird/"
# ============================================================================
# External Service Overrides (optional)
# Use these to point to external Signal, Relay, or STUN servers instead of
# running them locally. When set, the corresponding local service is disabled.
# ============================================================================
# External STUN servers - disables local STUN server
# stuns:
# - uri: "stun:stun.example.com:3478"
# - uri: "stun:stun.example.com:3479"
# External relay servers - disables local relay server
# relays:
# addresses:
# - "rels://relay.example.com:443"
# credentialsTTL: "12h"
# secret: "relay-shared-secret"
# External signal server - disables local signal server
# signalUri: "https://signal.example.com:443"
# ============================================================================
# Management Settings
# ============================================================================
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
# Embedded authentication/identity provider (Dex) configuration (always enabled)
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://server.mycompany.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.netbird.io/nb-auth"
- "https://app.netbird.io/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0
# trustedPeers: []

View File

@@ -0,0 +1,115 @@
# Simplified Combined NetBird Server Configuration
# Copy this file to config.yaml and customize for your deployment
# Server-wide settings
server:
# Main HTTP/gRPC port for all services (Management, Signal, Relay)
listenAddress: ":443"
# Metrics endpoint port
metricsPort: 9090
# Healthcheck endpoint address
healthcheckAddress: ":9000"
# Logging configuration
logLevel: "info" # panic, fatal, error, warn, info, debug, trace
logFile: "console" # "console" or path to log file
# TLS configuration (optional)
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
dataDir: ""
domains: []
email: ""
awsRoute53: false
# Relay service configuration
relay:
# Enable/disable the relay service
enabled: true
# Public address that peers will use to connect to this relay
# Format: hostname:port or ip:port
exposedAddress: "relay.example.com:443"
# Shared secret for relay authentication (required when enabled)
authSecret: "your-secret-key-here"
# Log level for relay (reserved for future use, currently uses global log level)
logLevel: "info"
# Embedded STUN server (optional)
stun:
enabled: false
ports: [3478]
logLevel: "info"
# Signal service configuration
signal:
# Enable/disable the signal service
enabled: true
# Log level for signal (reserved for future use, currently uses global log level)
logLevel: "info"
# Management service configuration
management:
# Enable/disable the management service
enabled: true
# Data directory for management service
dataDir: "/var/lib/netbird/"
# DNS domain for the management server
dnsDomain: ""
# Metrics and updates
disableAnonymousMetrics: false
disableGeoliteUpdate: false
auth:
# OIDC issuer URL - must be publicly accessible
issuer: "https://management.example.com/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
# OAuth2 redirect URIs for dashboard
dashboardRedirectURIs:
- "https://app.example.com/nb-auth"
- "https://app.example.com/nb-silent-auth"
# OAuth2 redirect URIs for CLI
cliRedirectURIs:
- "http://localhost:53000/"
# Optional initial admin user
# owner:
# email: "admin@example.com"
# password: "initial-password"
# External STUN servers (for client config)
stuns: []
# - uri: "stun:stun.example.com:3478"
# External relay servers (for client config)
relays:
addresses: []
# - "rels://relay.example.com:443"
credentialsTTL: "12h"
secret: ""
# External signal server URI (for client config)
signalUri: ""
# Store configuration
store:
engine: "sqlite" # sqlite, postgres, or mysql
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Reverse proxy settings
reverseProxy:
trustedHTTPProxies: []
trustedHTTPProxiesCount: 0
trustedPeers: []

13
combined/main.go Normal file
View File

@@ -0,0 +1,13 @@
package main
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/combined/cmd"
)
func main() {
if err := cmd.Execute(); err != nil {
log.Fatalf("failed to execute command: %v", err)
}
}

3
go.mod
View File

@@ -243,9 +243,10 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/dtls/v3 v3.0.11 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pion/transport/v4 v4.0.1 // indirect
github.com/pion/turn/v4 v4.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect

6
go.sum
View File

@@ -451,8 +451,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/dtls/v3 v3.0.11 h1:zqn8YhoAU7d9whsWLhNiQlbB8QdpJj8XQVSc5ImUons=
github.com/pion/dtls/v3 v3.0.11/go.mod h1:YEmmBYIoBsY3jmG56dsziTv/Lca9y4Om83370CXfqJ8=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
@@ -470,6 +470,8 @@ github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLh
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o=
github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=

File diff suppressed because it is too large Load Diff

View File

@@ -55,7 +55,7 @@ var (
// detect whether user specified a port
userPort := cmd.Flag("port").Changed
config, err = loadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
config, err = LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("failed reading provided config file: %s: %v", nbconfig.MgmtConfigPath, err)
}
@@ -133,35 +133,35 @@ var (
}
)
func loadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
func LoadMgmtConfig(ctx context.Context, mgmtConfigPath string) (*nbconfig.Config, error) {
loadedConfig := &nbconfig.Config{}
if _, err := util.ReadJsonWithEnvSub(mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
applyCommandLineOverrides(loadedConfig)
ApplyCommandLineOverrides(loadedConfig)
// Apply EmbeddedIdP config to HttpConfig if embedded IdP is enabled
err := applyEmbeddedIdPConfig(ctx, loadedConfig)
err := ApplyEmbeddedIdPConfig(ctx, loadedConfig)
if err != nil {
return nil, err
}
if err := applyOIDCConfig(ctx, loadedConfig); err != nil {
if err := ApplyOIDCConfig(ctx, loadedConfig); err != nil {
return nil, err
}
logConfigInfo(loadedConfig)
LogConfigInfo(loadedConfig)
if err := ensureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
if err := EnsureEncryptionKey(ctx, mgmtConfigPath, loadedConfig); err != nil {
return nil, err
}
return loadedConfig, nil
}
// applyCommandLineOverrides applies command-line flag overrides to the config
func applyCommandLineOverrides(cfg *nbconfig.Config) {
// ApplyCommandLineOverrides applies command-line flag overrides to the config
func ApplyCommandLineOverrides(cfg *nbconfig.Config) {
if mgmtLetsencryptDomain != "" {
cfg.HttpConfig.LetsEncryptDomain = mgmtLetsencryptDomain
}
@@ -174,9 +174,9 @@ func applyCommandLineOverrides(cfg *nbconfig.Config) {
}
}
// applyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// ApplyEmbeddedIdPConfig populates HttpConfig and EmbeddedIdP storage from config when embedded IdP is enabled.
// This allows users to only specify EmbeddedIdP config without duplicating values in HttpConfig.
func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
func ApplyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
if cfg.EmbeddedIdP == nil || !cfg.EmbeddedIdP.Enabled {
return nil
}
@@ -222,8 +222,8 @@ func applyEmbeddedIdPConfig(ctx context.Context, cfg *nbconfig.Config) error {
return nil
}
// applyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
// ApplyOIDCConfig fetches and applies OIDC configuration if endpoint is specified
func ApplyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcEndpoint := cfg.HttpConfig.OIDCConfigEndpoint
if oidcEndpoint == "" {
return nil
@@ -249,16 +249,16 @@ func applyOIDCConfig(ctx context.Context, cfg *nbconfig.Config) error {
oidcConfig.JwksURI, cfg.HttpConfig.AuthKeysLocation)
cfg.HttpConfig.AuthKeysLocation = oidcConfig.JwksURI
if err := applyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
if err := ApplyDeviceAuthFlowConfig(ctx, cfg, &oidcConfig, oidcEndpoint); err != nil {
return err
}
applyPKCEFlowConfig(ctx, cfg, &oidcConfig)
ApplyPKCEFlowConfig(ctx, cfg, &oidcConfig)
return nil
}
// applyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
// ApplyDeviceAuthFlowConfig applies OIDC config to DeviceAuthorizationFlow if enabled
func ApplyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse, oidcEndpoint string) error {
if cfg.DeviceAuthorizationFlow == nil || strings.ToLower(cfg.DeviceAuthorizationFlow.Provider) == string(nbconfig.NONE) {
return nil
}
@@ -285,8 +285,8 @@ func applyDeviceAuthFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcCo
return nil
}
// applyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
// ApplyPKCEFlowConfig applies OIDC config to PKCEAuthorizationFlow if configured
func ApplyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *OIDCConfigResponse) {
if cfg.PKCEAuthorizationFlow == nil {
return
}
@@ -299,8 +299,8 @@ func applyPKCEFlowConfig(ctx context.Context, cfg *nbconfig.Config, oidcConfig *
cfg.PKCEAuthorizationFlow.ProviderConfig.AuthorizationEndpoint = oidcConfig.AuthorizationEndpoint
}
// logConfigInfo logs informational messages about the loaded configuration
func logConfigInfo(cfg *nbconfig.Config) {
// LogConfigInfo logs informational messages about the loaded configuration
func LogConfigInfo(cfg *nbconfig.Config) {
if cfg.EmbeddedIdP != nil {
log.Infof("running with the embedded IdP: %v", cfg.EmbeddedIdP.Issuer)
}
@@ -309,8 +309,8 @@ func logConfigInfo(cfg *nbconfig.Config) {
}
}
// ensureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func ensureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
// EnsureEncryptionKey generates and saves a DataStoreEncryptionKey if not set
func EnsureEncryptionKey(ctx context.Context, configPath string, cfg *nbconfig.Config) error {
if cfg.DataStoreEncryptionKey != "" {
return nil
}

View File

@@ -30,7 +30,7 @@ func Test_loadMgmtConfig(t *testing.T) {
t.Fatalf("failed to create config: %s", err)
}
cfg, err := loadMgmtConfig(context.Background(), tmpFile)
cfg, err := LoadMgmtConfig(context.Background(), tmpFile)
if err != nil {
t.Fatalf("failed to load management config: %s", err)
}

View File

@@ -11,7 +11,6 @@ import (
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/management/server/idp"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
"golang.org/x/crypto/acme/autocert"
@@ -19,6 +18,8 @@ import (
"golang.org/x/net/http2/h2c"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/encryption"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/metrics"
@@ -138,6 +139,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
go metricsWorker.Run(srvCtx)
}
// Run afterInit hooks before starting any servers
// This allows registering additional gRPC services (e.g., Signal) before Serve() is called
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
var compatListener net.Listener
if s.mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
@@ -178,12 +187,6 @@ func (s *BaseServer) Start(ctx context.Context) error {
}
}
for _, fn := range s.afterInit {
if fn != nil {
fn(s)
}
}
log.WithContext(ctx).Infof("management server version %s", version.NetbirdVersion())
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
@@ -255,7 +258,23 @@ func (s *BaseServer) SetContainer(key string, container any) {
log.Tracef("container with key %s set successfully", key)
}
// SetHandlerFunc allows overriding the default HTTP handler function.
// This is useful for multiplexing additional services on the same port.
func (s *BaseServer) SetHandlerFunc(handler http.Handler) {
s.container["customHandler"] = handler
log.Tracef("custom handler set successfully")
}
func (s *BaseServer) handlerFunc(_ context.Context, gRPCHandler *grpc.Server, httpHandler http.Handler, meter metric.Meter) http.Handler {
// Check if a custom handler was set (for multiplexing additional services)
if customHandler, ok := s.GetContainer("customHandler"); ok {
if handler, ok := customHandler.(http.Handler); ok {
log.Tracef("using custom handler")
return handler
}
}
// Use default handler
wsProxy := wsproxyserver.New(gRPCHandler, wsproxyserver.WithOTelMeter(meter))
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {

View File

@@ -18,6 +18,8 @@ import (
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -368,9 +370,9 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false)
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
if err != nil {
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
@@ -380,9 +382,12 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
// If the user is regular user and does not own the peer
// with the given peerID return an empty list
if !user.HasAdminPower() && !user.IsServiceUser && !userAuth.IsChild {
if !allowed && !userAuth.IsChild {
if account.Settings.RegularUsersViewBlocked {
util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
return
}
peer, ok := account.Peers[peerID]
if !ok {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)

View File

@@ -22,6 +22,8 @@ import (
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -115,6 +117,16 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler {
ctrl2 := gomock.NewController(t)
permissionsManager := permissions.NewMockManager(ctrl2)
permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
permissionsManager.EXPECT().
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
DoAndReturn(func(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) {
user, ok := account.Users[userID]
if !ok {
return false, fmt.Errorf("user not found")
}
return user.HasAdminPower() || user.IsServiceUser, nil
}).
AnyTimes()
return &Handler{
accountManager: &mock_server.MockAccountManager{
@@ -383,12 +395,11 @@ func TestGetAccessiblePeers(t *testing.T) {
UserID: regularUser,
}
p := initTestMetaData(t, peer1, peer2, peer3)
tt := []struct {
name string
peerID string
callerUserID string
viewBlocked bool
expectedStatus int
expectedPeers []string
}{
@@ -427,10 +438,56 @@ func TestGetAccessiblePeers(t *testing.T) {
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
{
name: "regular user gets empty for owned peer list when view blocked",
peerID: "peer1",
callerUserID: regularUser,
viewBlocked: true,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "regular user gets empty list for unowned peer when view blocked",
peerID: "peer2",
callerUserID: regularUser,
viewBlocked: true,
expectedStatus: http.StatusOK,
expectedPeers: []string{},
},
{
name: "admin user still sees accessible peers when view blocked",
peerID: "peer2",
callerUserID: adminUser,
viewBlocked: true,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer3"},
},
{
name: "service user still sees accessible peers when view blocked",
peerID: "peer3",
callerUserID: serviceUser,
viewBlocked: true,
expectedStatus: http.StatusOK,
expectedPeers: []string{"peer1", "peer2"},
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
p := initTestMetaData(t, peer1, peer2, peer3)
if tc.viewBlocked {
mockAM := p.accountManager.(*mock_server.MockAccountManager)
originalGetAccountByIDFunc := mockAM.GetAccountByIDFunc
mockAM.GetAccountByIDFunc = func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
account, err := originalGetAccountByIDFunc(ctx, accountID, userID)
if err != nil {
return nil, err
}
account.Settings.RegularUsersViewBlocked = true
return account, nil
}
}
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)

View File

@@ -2643,7 +2643,7 @@ func getGormConfig() *gorm.Config {
// newPostgresStore initializes a new Postgres store.
func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
dsn, ok := os.LookupEnv(postgresDsnEnv)
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
if !ok {
return nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
@@ -2652,7 +2652,7 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics, skipMig
// newMysqlStore initializes a new MySQL store.
func newMysqlStore(ctx context.Context, metrics telemetry.AppMetrics, skipMigration bool) (Store, error) {
dsn, ok := os.LookupEnv(mysqlDsnEnv)
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
if !ok {
return nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}

View File

@@ -243,10 +243,20 @@ type Store interface {
}
const (
postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
postgresDsnEnv = "NB_STORE_ENGINE_POSTGRES_DSN"
postgresDsnEnvLegacy = "NETBIRD_STORE_ENGINE_POSTGRES_DSN"
mysqlDsnEnv = "NB_STORE_ENGINE_MYSQL_DSN"
mysqlDsnEnvLegacy = "NETBIRD_STORE_ENGINE_MYSQL_DSN"
)
// lookupDSNEnv checks the NB_ env var first, then falls back to the legacy NETBIRD_ env var.
func lookupDSNEnv(nbKey, legacyKey string) (string, bool) {
if v, ok := os.LookupEnv(nbKey); ok {
return v, true
}
return os.LookupEnv(legacyKey)
}
var supportedEngines = []types.Engine{types.SqliteStoreEngine, types.PostgresStoreEngine, types.MysqlStoreEngine}
func getStoreEngineFromEnv() types.Engine {
@@ -531,7 +541,7 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
}
func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
dsn, ok := os.LookupEnv(postgresDsnEnv)
dsn, ok := lookupDSNEnv(postgresDsnEnv, postgresDsnEnvLegacy)
if !ok || dsn == "" {
var err error
_, dsn, err = testutil.CreatePostgresTestContainer()
@@ -569,7 +579,7 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
}
func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine) (*SqlStore, func(), error) {
dsn, ok := os.LookupEnv(mysqlDsnEnv)
dsn, ok := lookupDSNEnv(mysqlDsnEnv, mysqlDsnEnvLegacy)
if !ok || dsn == "" {
var err error
_, dsn, err = testutil.CreateMysqlTestContainer()

View File

@@ -122,6 +122,7 @@ type defaultAppMetrics struct {
Meter metric2.Meter
listener net.Listener
ctx context.Context
externallyManaged bool
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
@@ -171,6 +172,9 @@ func (appMetrics *defaultAppMetrics) Close() error {
// Expose metrics on a given port and endpoint. If endpoint is empty a defaultEndpoint one will be used.
// Exposes metrics in the Prometheus format https://prometheus.io/
func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpoint string) error {
if appMetrics.externallyManaged {
return nil
}
if endpoint == "" {
endpoint = defaultEndpoint
}
@@ -252,3 +256,49 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
accountManagerMetrics: accountManagerMetrics,
}, nil
}
// NewAppMetricsWithMeter creates AppMetrics using an externally provided meter.
// The caller is responsible for exposing metrics via HTTP. Expose() and Close() are no-ops.
func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetrics, error) {
idpMetrics, err := NewIDPMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize IDP metrics: %w", err)
}
middleware, err := NewMetricsMiddleware(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize HTTP middleware metrics: %w", err)
}
grpcMetrics, err := NewGRPCMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize gRPC metrics: %w", err)
}
storeMetrics, err := NewStoreMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize store metrics: %w", err)
}
updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize update channel metrics: %w", err)
}
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
externallyManaged: true,
idpMetrics: idpMetrics,
httpMiddleware: middleware,
grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
}, nil
}

View File

@@ -21,8 +21,8 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
)

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"crypto/tls"
"net"
"net/url"
"sync"
@@ -134,3 +135,10 @@ func (r *Server) ListenerProtocols() []protocol.Protocol {
func (r *Server) InstanceURL() url.URL {
return r.relay.InstanceURL()
}
// RelayAccept returns the relay's Accept function for handling incoming connections.
// This allows external HTTP handlers to route connections to the relay without
// starting the relay's own listeners.
func (r *Server) RelayAccept() func(conn net.Conn) {
return r.relay.Accept
}

View File

@@ -0,0 +1,82 @@
package rest
import (
"context"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// BillingAPI APIs for billing and invoices
type BillingAPI struct {
c *Client
}
// GetUsage retrieves current usage statistics for the account
// See more: https://docs.netbird.io/api/resources/billing#get-current-usage
func (a *BillingAPI) GetUsage(ctx context.Context) (*api.UsageStats, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/usage", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UsageStats](resp)
return &ret, err
}
// GetSubscription retrieves the current subscription details
// See more: https://docs.netbird.io/api/resources/billing#get-current-subscription
func (a *BillingAPI) GetSubscription(ctx context.Context) (*api.Subscription, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/subscription", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Subscription](resp)
return &ret, err
}
// GetInvoices retrieves the account's paid invoices
// See more: https://docs.netbird.io/api/resources/billing#list-all-invoices
func (a *BillingAPI) GetInvoices(ctx context.Context) ([]api.InvoiceResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.InvoiceResponse](resp)
return ret, err
}
// GetInvoicePDF retrieves the invoice PDF URL
// See more: https://docs.netbird.io/api/resources/billing#get-invoice-pdf
func (a *BillingAPI) GetInvoicePDF(ctx context.Context, invoiceID string) (*api.InvoicePDFResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices/"+invoiceID+"/pdf", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.InvoicePDFResponse](resp)
return &ret, err
}
// GetInvoiceCSV retrieves the invoice CSV content
// See more: https://docs.netbird.io/api/resources/billing#get-invoice-csv
func (a *BillingAPI) GetInvoiceCSV(ctx context.Context, invoiceID string) (string, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/billing/invoices/"+invoiceID+"/csv", nil, nil)
if err != nil {
return "", err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[string](resp)
return ret, err
}

View File

@@ -0,0 +1,194 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testUsageStats = api.UsageStats{
ActiveUsers: 15,
TotalUsers: 20,
ActivePeers: 10,
TotalPeers: 25,
}
testSubscription = api.Subscription{
Active: true,
PlanTier: "basic",
PriceId: "price_1HhxOp",
Currency: "USD",
Price: 1000,
Provider: "stripe",
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
testInvoice = api.InvoiceResponse{
Id: "inv_123",
PeriodStart: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
PeriodEnd: time.Date(2024, 2, 1, 0, 0, 0, 0, time.UTC),
Type: "invoice",
}
testInvoicePDF = api.InvoicePDFResponse{
Url: "https://example.com/invoice.pdf",
}
)
func TestBilling_GetUsage_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/usage", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testUsageStats)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetUsage(context.Background())
require.NoError(t, err)
assert.Equal(t, testUsageStats, *ret)
})
}
func TestBilling_GetUsage_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/usage", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetUsage(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestBilling_GetSubscription_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/subscription", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testSubscription)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetSubscription(context.Background())
require.NoError(t, err)
assert.Equal(t, testSubscription, *ret)
})
}
func TestBilling_GetSubscription_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/subscription", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetSubscription(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestBilling_GetInvoices_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.InvoiceResponse{testInvoice})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoices(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testInvoice, ret[0])
})
}
func TestBilling_GetInvoices_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoices(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestBilling_GetInvoicePDF_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/pdf", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testInvoicePDF)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoicePDF(context.Background(), "inv_123")
require.NoError(t, err)
assert.Equal(t, testInvoicePDF, *ret)
})
}
func TestBilling_GetInvoicePDF_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/pdf", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoicePDF(context.Background(), "inv_123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestBilling_GetInvoiceCSV_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/csv", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal("col1,col2\nval1,val2")
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoiceCSV(context.Background(), "inv_123")
require.NoError(t, err)
assert.Equal(t, "col1,col2\nval1,val2", ret)
})
}
func TestBilling_GetInvoiceCSV_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/billing/invoices/inv_123/csv", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Billing.GetInvoiceCSV(context.Background(), "inv_123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}

View File

@@ -73,6 +73,38 @@ type Client struct {
// Events NetBird Events APIs
// see more: https://docs.netbird.io/api/resources/events
Events *EventsAPI
// Billing NetBird Billing APIs for subscriptions, plans, and invoices
// see more: https://docs.netbird.io/api/resources/billing
Billing *BillingAPI
// MSP NetBird MSP tenant management APIs
// see more: https://docs.netbird.io/api/resources/msp
MSP *MSPAPI
// EDR NetBird EDR integration APIs (Intune, SentinelOne, Falcon, Huntress)
// see more: https://docs.netbird.io/api/resources/edr
EDR *EDRAPI
// SCIM NetBird SCIM IDP integration APIs
// see more: https://docs.netbird.io/api/resources/scim
SCIM *SCIMAPI
// EventStreaming NetBird Event Streaming integration APIs
// see more: https://docs.netbird.io/api/resources/event-streaming
EventStreaming *EventStreamingAPI
// IdentityProviders NetBird Identity Providers APIs
// see more: https://docs.netbird.io/api/resources/identity-providers
IdentityProviders *IdentityProvidersAPI
// Ingress NetBird Ingress Peers APIs
// see more: https://docs.netbird.io/api/resources/ingress-ports
Ingress *IngressAPI
// Instance NetBird Instance API
// see more: https://docs.netbird.io/api/resources/instance
Instance *InstanceAPI
}
// New initialize new Client instance using PAT token
@@ -120,6 +152,14 @@ func (c *Client) initialize() {
c.DNSZones = &DNSZonesAPI{c}
c.GeoLocation = &GeoLocationAPI{c}
c.Events = &EventsAPI{c}
c.Billing = &BillingAPI{c}
c.MSP = &MSPAPI{c}
c.EDR = &EDRAPI{c}
c.SCIM = &SCIMAPI{c}
c.EventStreaming = &EventStreamingAPI{c}
c.IdentityProviders = &IdentityProvidersAPI{c}
c.Ingress = &IngressAPI{c}
c.Instance = &InstanceAPI{c}
}
// NewRequest creates and executes new management API request

View File

@@ -0,0 +1,307 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// EDRAPI APIs for EDR integrations (Intune, SentinelOne, Falcon, Huntress)
type EDRAPI struct {
c *Client
}
// GetIntuneIntegration retrieves the EDR Intune integration
// See more: https://docs.netbird.io/api/resources/edr#get-intune-integration
func (a *EDRAPI) GetIntuneIntegration(ctx context.Context) (*api.EDRIntuneResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/intune", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRIntuneResponse](resp)
return &ret, err
}
// CreateIntuneIntegration creates a new EDR Intune integration
// See more: https://docs.netbird.io/api/resources/edr#create-intune-integration
func (a *EDRAPI) CreateIntuneIntegration(ctx context.Context, request api.EDRIntuneRequest) (*api.EDRIntuneResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/intune", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRIntuneResponse](resp)
return &ret, err
}
// UpdateIntuneIntegration updates an existing EDR Intune integration
// See more: https://docs.netbird.io/api/resources/edr#update-intune-integration
func (a *EDRAPI) UpdateIntuneIntegration(ctx context.Context, request api.EDRIntuneRequest) (*api.EDRIntuneResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/intune", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRIntuneResponse](resp)
return &ret, err
}
// DeleteIntuneIntegration deletes the EDR Intune integration
// See more: https://docs.netbird.io/api/resources/edr#delete-intune-integration
func (a *EDRAPI) DeleteIntuneIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/intune", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// GetSentinelOneIntegration retrieves the EDR SentinelOne integration
// See more: https://docs.netbird.io/api/resources/edr#get-sentinelone-integration
func (a *EDRAPI) GetSentinelOneIntegration(ctx context.Context) (*api.EDRSentinelOneResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/sentinelone", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
return &ret, err
}
// CreateSentinelOneIntegration creates a new EDR SentinelOne integration
// See more: https://docs.netbird.io/api/resources/edr#create-sentinelone-integration
func (a *EDRAPI) CreateSentinelOneIntegration(ctx context.Context, request api.EDRSentinelOneRequest) (*api.EDRSentinelOneResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/sentinelone", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
return &ret, err
}
// UpdateSentinelOneIntegration updates an existing EDR SentinelOne integration
// See more: https://docs.netbird.io/api/resources/edr#update-sentinelone-integration
func (a *EDRAPI) UpdateSentinelOneIntegration(ctx context.Context, request api.EDRSentinelOneRequest) (*api.EDRSentinelOneResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/sentinelone", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRSentinelOneResponse](resp)
return &ret, err
}
// DeleteSentinelOneIntegration deletes the EDR SentinelOne integration
// See more: https://docs.netbird.io/api/resources/edr#delete-sentinelone-integration
func (a *EDRAPI) DeleteSentinelOneIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/sentinelone", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// GetFalconIntegration retrieves the EDR Falcon integration
// See more: https://docs.netbird.io/api/resources/edr#get-falcon-integration
func (a *EDRAPI) GetFalconIntegration(ctx context.Context) (*api.EDRFalconResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/falcon", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFalconResponse](resp)
return &ret, err
}
// CreateFalconIntegration creates a new EDR Falcon integration
// See more: https://docs.netbird.io/api/resources/edr#create-falcon-integration
func (a *EDRAPI) CreateFalconIntegration(ctx context.Context, request api.EDRFalconRequest) (*api.EDRFalconResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/falcon", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFalconResponse](resp)
return &ret, err
}
// UpdateFalconIntegration updates an existing EDR Falcon integration
// See more: https://docs.netbird.io/api/resources/edr#update-falcon-integration
func (a *EDRAPI) UpdateFalconIntegration(ctx context.Context, request api.EDRFalconRequest) (*api.EDRFalconResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/falcon", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFalconResponse](resp)
return &ret, err
}
// DeleteFalconIntegration deletes the EDR Falcon integration
// See more: https://docs.netbird.io/api/resources/edr#delete-falcon-integration
func (a *EDRAPI) DeleteFalconIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/falcon", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// GetHuntressIntegration retrieves the EDR Huntress integration
// See more: https://docs.netbird.io/api/resources/edr#get-huntress-integration
func (a *EDRAPI) GetHuntressIntegration(ctx context.Context) (*api.EDRHuntressResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/huntress", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRHuntressResponse](resp)
return &ret, err
}
// CreateHuntressIntegration creates a new EDR Huntress integration
// See more: https://docs.netbird.io/api/resources/edr#create-huntress-integration
func (a *EDRAPI) CreateHuntressIntegration(ctx context.Context, request api.EDRHuntressRequest) (*api.EDRHuntressResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/huntress", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRHuntressResponse](resp)
return &ret, err
}
// UpdateHuntressIntegration updates an existing EDR Huntress integration
// See more: https://docs.netbird.io/api/resources/edr#update-huntress-integration
func (a *EDRAPI) UpdateHuntressIntegration(ctx context.Context, request api.EDRHuntressRequest) (*api.EDRHuntressResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/huntress", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRHuntressResponse](resp)
return &ret, err
}
// DeleteHuntressIntegration deletes the EDR Huntress integration
// See more: https://docs.netbird.io/api/resources/edr#delete-huntress-integration
func (a *EDRAPI) DeleteHuntressIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/huntress", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// BypassPeerCompliance bypasses compliance for a non-compliant peer
// See more: https://docs.netbird.io/api/resources/edr#bypass-peer-compliance
func (a *EDRAPI) BypassPeerCompliance(ctx context.Context, peerID string) (*api.BypassResponse, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+peerID+"/edr/bypass", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.BypassResponse](resp)
return &ret, err
}
// RevokePeerBypass revokes the compliance bypass for a peer
// See more: https://docs.netbird.io/api/resources/edr#revoke-peer-bypass
func (a *EDRAPI) RevokePeerBypass(ctx context.Context, peerID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID+"/edr/bypass", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// ListBypassedPeers returns all peers that have compliance bypassed
// See more: https://docs.netbird.io/api/resources/edr#list-all-bypassed-peers
func (a *EDRAPI) ListBypassedPeers(ctx context.Context) ([]api.BypassResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/edr/bypassed", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.BypassResponse](resp)
return ret, err
}

View File

@@ -0,0 +1,422 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testIntuneResponse = api.EDRIntuneResponse{
AccountId: "acc-1",
ClientId: "client-1",
TenantId: "tenant-1",
Enabled: true,
Id: 1,
Groups: []api.Group{},
LastSyncedInterval: 24,
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
CreatedBy: "user-1",
}
testSentinelOneResponse = api.EDRSentinelOneResponse{
AccountId: "acc-1",
ApiUrl: "https://sentinelone.example.com",
Enabled: true,
Id: 2,
Groups: []api.Group{},
LastSyncedInterval: 24,
MatchAttributes: api.SentinelOneMatchAttributes{},
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
CreatedBy: "user-1",
}
testFalconResponse = api.EDRFalconResponse{
AccountId: "acc-1",
CloudId: "us-1",
Enabled: true,
Id: 3,
Groups: []api.Group{},
ZtaScoreThreshold: 50,
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
CreatedBy: "user-1",
}
testHuntressResponse = api.EDRHuntressResponse{
AccountId: "acc-1",
Enabled: true,
Id: 4,
Groups: []api.Group{},
LastSyncedInterval: 24,
MatchAttributes: api.HuntressMatchAttributes{},
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
CreatedBy: "user-1",
}
testBypassResponse = api.BypassResponse{
PeerId: "peer-1",
}
)
// Intune tests
func TestEDR_GetIntuneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testIntuneResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetIntuneIntegration(context.Background())
require.NoError(t, err)
assert.Equal(t, testIntuneResponse, *ret)
})
}
func TestEDR_GetIntuneIntegration_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetIntuneIntegration(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestEDR_CreateIntuneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.EDRIntuneRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "client-1", req.ClientId)
retBytes, _ := json.Marshal(testIntuneResponse)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateIntuneIntegration(context.Background(), api.EDRIntuneRequest{
ClientId: "client-1",
Secret: "secret",
TenantId: "tenant-1",
Groups: []string{"group-1"},
LastSyncedInterval: 24,
})
require.NoError(t, err)
assert.Equal(t, testIntuneResponse, *ret)
})
}
func TestEDR_CreateIntuneIntegration_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateIntuneIntegration(context.Background(), api.EDRIntuneRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestEDR_UpdateIntuneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
retBytes, _ := json.Marshal(testIntuneResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.UpdateIntuneIntegration(context.Background(), api.EDRIntuneRequest{
ClientId: "client-1",
Secret: "new-secret",
TenantId: "tenant-1",
Groups: []string{"group-1"},
})
require.NoError(t, err)
assert.Equal(t, testIntuneResponse, *ret)
})
}
func TestEDR_DeleteIntuneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.DeleteIntuneIntegration(context.Background())
require.NoError(t, err)
})
}
func TestEDR_DeleteIntuneIntegration_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/intune", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.EDR.DeleteIntuneIntegration(context.Background())
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
// SentinelOne tests
func TestEDR_GetSentinelOneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testSentinelOneResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetSentinelOneIntegration(context.Background())
require.NoError(t, err)
assert.Equal(t, testSentinelOneResponse, *ret)
})
}
func TestEDR_CreateSentinelOneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testSentinelOneResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateSentinelOneIntegration(context.Background(), api.EDRSentinelOneRequest{
ApiToken: "token",
ApiUrl: "https://sentinelone.example.com",
Groups: []string{"group-1"},
LastSyncedInterval: 24,
MatchAttributes: api.SentinelOneMatchAttributes{},
})
require.NoError(t, err)
assert.Equal(t, testSentinelOneResponse, *ret)
})
}
func TestEDR_DeleteSentinelOneIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/sentinelone", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.DeleteSentinelOneIntegration(context.Background())
require.NoError(t, err)
})
}
// Falcon tests
func TestEDR_GetFalconIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testFalconResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetFalconIntegration(context.Background())
require.NoError(t, err)
assert.Equal(t, testFalconResponse, *ret)
})
}
func TestEDR_CreateFalconIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testFalconResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateFalconIntegration(context.Background(), api.EDRFalconRequest{
ClientId: "client-1",
Secret: "secret",
CloudId: "us-1",
Groups: []string{"group-1"},
ZtaScoreThreshold: 50,
})
require.NoError(t, err)
assert.Equal(t, testFalconResponse, *ret)
})
}
func TestEDR_DeleteFalconIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/falcon", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.DeleteFalconIntegration(context.Background())
require.NoError(t, err)
})
}
// Huntress tests
func TestEDR_GetHuntressIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testHuntressResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.GetHuntressIntegration(context.Background())
require.NoError(t, err)
assert.Equal(t, testHuntressResponse, *ret)
})
}
func TestEDR_CreateHuntressIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testHuntressResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.CreateHuntressIntegration(context.Background(), api.EDRHuntressRequest{
ApiKey: "key",
ApiSecret: "secret",
Groups: []string{"group-1"},
LastSyncedInterval: 24,
MatchAttributes: api.HuntressMatchAttributes{},
})
require.NoError(t, err)
assert.Equal(t, testHuntressResponse, *ret)
})
}
func TestEDR_DeleteHuntressIntegration_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/edr/huntress", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.DeleteHuntressIntegration(context.Background())
require.NoError(t, err)
})
}
// Peer bypass tests
func TestEDR_BypassPeerCompliance_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testBypassResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.BypassPeerCompliance(context.Background(), "peer-1")
require.NoError(t, err)
assert.Equal(t, testBypassResponse, *ret)
})
}
func TestEDR_BypassPeerCompliance_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Bad request", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.BypassPeerCompliance(context.Background(), "peer-1")
assert.Error(t, err)
assert.Equal(t, "Bad request", err.Error())
assert.Nil(t, ret)
})
}
func TestEDR_RevokePeerBypass_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EDR.RevokePeerBypass(context.Background(), "peer-1")
require.NoError(t, err)
})
}
func TestEDR_RevokePeerBypass_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/peer-1/edr/bypass", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.EDR.RevokePeerBypass(context.Background(), "peer-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestEDR_ListBypassedPeers_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/edr/bypassed", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.BypassResponse{testBypassResponse})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.ListBypassedPeers(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testBypassResponse, ret[0])
})
}
func TestEDR_ListBypassedPeers_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/edr/bypassed", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EDR.ListBypassedPeers(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"strconv"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// EventStreamingAPI APIs for event streaming integrations
type EventStreamingAPI struct {
c *Client
}
// List retrieves all event streaming integrations
// See more: https://docs.netbird.io/api/resources/event-streaming#list-all-event-streaming-integrations
func (a *EventStreamingAPI) List(ctx context.Context) ([]api.IntegrationResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/event-streaming", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IntegrationResponse](resp)
return ret, err
}
// Get retrieves a specific event streaming integration by ID
// See more: https://docs.netbird.io/api/resources/event-streaming#retrieve-an-event-streaming-integration
func (a *EventStreamingAPI) Get(ctx context.Context, integrationID int) (*api.IntegrationResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/event-streaming/"+strconv.Itoa(integrationID), nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IntegrationResponse](resp)
return &ret, err
}
// Create creates a new event streaming integration
// See more: https://docs.netbird.io/api/resources/event-streaming#create-an-event-streaming-integration
func (a *EventStreamingAPI) Create(ctx context.Context, request api.CreateIntegrationRequest) (*api.IntegrationResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/event-streaming", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IntegrationResponse](resp)
return &ret, err
}
// Update updates an existing event streaming integration
// See more: https://docs.netbird.io/api/resources/event-streaming#update-an-event-streaming-integration
func (a *EventStreamingAPI) Update(ctx context.Context, integrationID int, request api.CreateIntegrationRequest) (*api.IntegrationResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/event-streaming/"+strconv.Itoa(integrationID), bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IntegrationResponse](resp)
return &ret, err
}
// Delete deletes an event streaming integration
// See more: https://docs.netbird.io/api/resources/event-streaming#delete-an-event-streaming-integration
func (a *EventStreamingAPI) Delete(ctx context.Context, integrationID int) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/event-streaming/"+strconv.Itoa(integrationID), nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,194 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testIntegrationResponse = api.IntegrationResponse{
Id: ptr[int64](1),
AccountId: ptr("acc-1"),
Platform: (*api.IntegrationResponsePlatform)(ptr("datadog")),
Enabled: ptr(true),
Config: &map[string]string{"api_key": "****"},
CreatedAt: ptr(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)),
UpdatedAt: ptr(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)),
}
)
func TestEventStreaming_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.IntegrationResponse{testIntegrationResponse})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testIntegrationResponse, ret[0])
})
}
func TestEventStreaming_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestEventStreaming_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testIntegrationResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Get(context.Background(), 1)
require.NoError(t, err)
assert.Equal(t, testIntegrationResponse, *ret)
})
}
func TestEventStreaming_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Get(context.Background(), 1)
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestEventStreaming_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, api.CreateIntegrationRequestPlatformDatadog, req.Platform)
assert.Equal(t, true, req.Enabled)
retBytes, _ := json.Marshal(testIntegrationResponse)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Create(context.Background(), api.CreateIntegrationRequest{
Platform: api.CreateIntegrationRequestPlatformDatadog,
Enabled: true,
Config: map[string]string{"api_key": "test-key"},
})
require.NoError(t, err)
assert.Equal(t, testIntegrationResponse, *ret)
})
}
func TestEventStreaming_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Create(context.Background(), api.CreateIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestEventStreaming_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, false, req.Enabled)
retBytes, _ := json.Marshal(testIntegrationResponse)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Update(context.Background(), 1, api.CreateIntegrationRequest{
Platform: api.CreateIntegrationRequestPlatformDatadog,
Enabled: false,
Config: map[string]string{"api_key": "updated-key"},
})
require.NoError(t, err)
assert.Equal(t, testIntegrationResponse, *ret)
})
}
func TestEventStreaming_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.EventStreaming.Update(context.Background(), 1, api.CreateIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestEventStreaming_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.EventStreaming.Delete(context.Background(), 1)
require.NoError(t, err)
})
}
func TestEventStreaming_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/event-streaming/1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.EventStreaming.Delete(context.Background(), 1)
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}

View File

@@ -2,6 +2,8 @@ package rest
import (
"context"
"fmt"
"time"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -11,10 +13,79 @@ type EventsAPI struct {
c *Client
}
// List list all events
// See more: https://docs.netbird.io/api/resources/events#list-all-events
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil, nil)
// NetworkTrafficOption options for ListNetworkTrafficEvents API
type NetworkTrafficOption func(query map[string]string)
func NetworkTrafficPage(page int) NetworkTrafficOption {
return func(query map[string]string) {
query["page"] = fmt.Sprintf("%d", page)
}
}
func NetworkTrafficPageSize(pageSize int) NetworkTrafficOption {
return func(query map[string]string) {
query["page_size"] = fmt.Sprintf("%d", pageSize)
}
}
func NetworkTrafficUserID(userID string) NetworkTrafficOption {
return func(query map[string]string) {
query["user_id"] = userID
}
}
func NetworkTrafficReporterID(reporterID string) NetworkTrafficOption {
return func(query map[string]string) {
query["reporter_id"] = reporterID
}
}
func NetworkTrafficProtocol(protocol int) NetworkTrafficOption {
return func(query map[string]string) {
query["protocol"] = fmt.Sprintf("%d", protocol)
}
}
func NetworkTrafficType(t api.GetApiEventsNetworkTrafficParamsType) NetworkTrafficOption {
return func(query map[string]string) {
query["type"] = string(t)
}
}
func NetworkTrafficConnectionType(ct api.GetApiEventsNetworkTrafficParamsConnectionType) NetworkTrafficOption {
return func(query map[string]string) {
query["connection_type"] = string(ct)
}
}
func NetworkTrafficDirection(d api.GetApiEventsNetworkTrafficParamsDirection) NetworkTrafficOption {
return func(query map[string]string) {
query["direction"] = string(d)
}
}
func NetworkTrafficSearch(search string) NetworkTrafficOption {
return func(query map[string]string) {
query["search"] = search
}
}
func NetworkTrafficStartDate(t time.Time) NetworkTrafficOption {
return func(query map[string]string) {
query["start_date"] = t.Format(time.RFC3339)
}
}
func NetworkTrafficEndDate(t time.Time) NetworkTrafficOption {
return func(query map[string]string) {
query["end_date"] = t.Format(time.RFC3339)
}
}
// ListAuditEvents list all audit events
// See more: https://docs.netbird.io/api/resources/events#list-all-audit-events
func (a *EventsAPI) ListAuditEvents(ctx context.Context) ([]api.Event, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/events/audit", nil, nil)
if err != nil {
return nil, err
}
@@ -24,3 +95,21 @@ func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
ret, err := parseResponse[[]api.Event](resp)
return ret, err
}
// ListNetworkTrafficEvents list network traffic events
// See more: https://docs.netbird.io/api/resources/events#list-network-traffic-events
func (a *EventsAPI) ListNetworkTrafficEvents(ctx context.Context, opts ...NetworkTrafficOption) (*api.NetworkTrafficEventsResponse, error) {
query := make(map[string]string)
for _, o := range opts {
o(query)
}
resp, err := a.c.NewRequest(ctx, "GET", "/api/events/network-traffic", nil, query)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkTrafficEventsResponse](resp)
return &ret, err
}

View File

@@ -21,37 +21,76 @@ var (
Activity: "AccountCreate",
ActivityCode: api.EventActivityCodeAccountCreate,
}
testNetworkTrafficResponse = api.NetworkTrafficEventsResponse{
Data: []api.NetworkTrafficEvent{},
Page: 1,
PageSize: 50,
}
)
func TestEvents_List_200(t *testing.T) {
func TestEvents_ListAuditEvents_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/api/events/audit", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Event{testEvent})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.List(context.Background())
ret, err := c.Events.ListAuditEvents(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testEvent, ret[0])
})
}
func TestEvents_List_Err(t *testing.T) {
func TestEvents_ListAuditEvents_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/api/events/audit", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.List(context.Background())
ret, err := c.Events.ListAuditEvents(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestEvents_ListNetworkTrafficEvents_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events/network-traffic", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "1", r.URL.Query().Get("page"))
assert.Equal(t, "50", r.URL.Query().Get("page_size"))
retBytes, _ := json.Marshal(testNetworkTrafficResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.ListNetworkTrafficEvents(context.Background(),
rest.NetworkTrafficPage(1),
rest.NetworkTrafficPageSize(50),
)
require.NoError(t, err)
assert.Equal(t, testNetworkTrafficResponse, *ret)
})
}
func TestEvents_ListNetworkTrafficEvents_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events/network-traffic", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.ListNetworkTrafficEvents(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestEvents_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
// Do something that would trigger any event
@@ -62,7 +101,7 @@ func TestEvents_Integration(t *testing.T) {
})
require.NoError(t, err)
events, err := c.Events.List(context.Background())
events, err := c.Events.ListAuditEvents(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, events)
})

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// IdentityProvidersAPI APIs for Identity Providers, do not use directly
type IdentityProvidersAPI struct {
c *Client
}
// List all identity providers
// See more: https://docs.netbird.io/api/resources/identity-providers#list-all-identity-providers
func (a *IdentityProvidersAPI) List(ctx context.Context) ([]api.IdentityProvider, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/identity-providers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IdentityProvider](resp)
return ret, err
}
// Get identity provider info
// See more: https://docs.netbird.io/api/resources/identity-providers#retrieve-an-identity-provider
func (a *IdentityProvidersAPI) Get(ctx context.Context, idpID string) (*api.IdentityProvider, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/identity-providers/"+idpID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IdentityProvider](resp)
return &ret, err
}
// Create new identity provider
// See more: https://docs.netbird.io/api/resources/identity-providers#create-an-identity-provider
func (a *IdentityProvidersAPI) Create(ctx context.Context, request api.PostApiIdentityProvidersJSONRequestBody) (*api.IdentityProvider, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/identity-providers", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IdentityProvider](resp)
return &ret, err
}
// Update update identity provider
// See more: https://docs.netbird.io/api/resources/identity-providers#update-an-identity-provider
func (a *IdentityProvidersAPI) Update(ctx context.Context, idpID string, request api.PutApiIdentityProvidersIdpIdJSONRequestBody) (*api.IdentityProvider, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/identity-providers/"+idpID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IdentityProvider](resp)
return &ret, err
}
// Delete delete identity provider
// See more: https://docs.netbird.io/api/resources/identity-providers#delete-an-identity-provider
func (a *IdentityProvidersAPI) Delete(ctx context.Context, idpID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/identity-providers/"+idpID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,183 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var testIdentityProvider = api.IdentityProvider{
ClientId: "test-client-id",
Id: ptr("Test"),
}
func TestIdentityProviders_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.IdentityProvider{testIdentityProvider})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testIdentityProvider, ret[0])
})
}
func TestIdentityProviders_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestIdentityProviders_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testIdentityProvider)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testIdentityProvider, *ret)
})
}
func TestIdentityProviders_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestIdentityProviders_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiIdentityProvidersJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "new-client-id", req.ClientId)
retBytes, _ := json.Marshal(testIdentityProvider)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Create(context.Background(), api.PostApiIdentityProvidersJSONRequestBody{
ClientId: "new-client-id",
})
require.NoError(t, err)
assert.Equal(t, testIdentityProvider, *ret)
})
}
func TestIdentityProviders_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Create(context.Background(), api.PostApiIdentityProvidersJSONRequestBody{
ClientId: "new-client-id",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestIdentityProviders_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiIdentityProvidersIdpIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "updated-client-id", req.ClientId)
retBytes, _ := json.Marshal(testIdentityProvider)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Update(context.Background(), "Test", api.PutApiIdentityProvidersIdpIdJSONRequestBody{
ClientId: "updated-client-id",
})
require.NoError(t, err)
assert.Equal(t, testIdentityProvider, *ret)
})
}
func TestIdentityProviders_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.IdentityProviders.Update(context.Background(), "Test", api.PutApiIdentityProvidersIdpIdJSONRequestBody{
ClientId: "updated-client-id",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestIdentityProviders_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.IdentityProviders.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestIdentityProviders_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/identity-providers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.IdentityProviders.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// IngressAPI APIs for Ingress Peers, do not use directly
type IngressAPI struct {
c *Client
}
// List all ingress peers
// See more: https://docs.netbird.io/api/resources/ingress#list-all-ingress-peers
func (a *IngressAPI) List(ctx context.Context) ([]api.IngressPeer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/ingress/peers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IngressPeer](resp)
return ret, err
}
// Get ingress peer info
// See more: https://docs.netbird.io/api/resources/ingress#retrieve-an-ingress-peer
func (a *IngressAPI) Get(ctx context.Context, ingressPeerID string) (*api.IngressPeer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/ingress/peers/"+ingressPeerID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPeer](resp)
return &ret, err
}
// Create new ingress peer
// See more: https://docs.netbird.io/api/resources/ingress#create-an-ingress-peer
func (a *IngressAPI) Create(ctx context.Context, request api.PostApiIngressPeersJSONRequestBody) (*api.IngressPeer, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/ingress/peers", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPeer](resp)
return &ret, err
}
// Update update ingress peer
// See more: https://docs.netbird.io/api/resources/ingress#update-an-ingress-peer
func (a *IngressAPI) Update(ctx context.Context, ingressPeerID string, request api.PutApiIngressPeersIngressPeerIdJSONRequestBody) (*api.IngressPeer, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/ingress/peers/"+ingressPeerID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPeer](resp)
return &ret, err
}
// Delete delete ingress peer
// See more: https://docs.netbird.io/api/resources/ingress#delete-an-ingress-peer
func (a *IngressAPI) Delete(ctx context.Context, ingressPeerID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/ingress/peers/"+ingressPeerID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,184 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var testIngressPeer = api.IngressPeer{
Connected: true,
Enabled: true,
Id: "Test",
}
func TestIngress_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.IngressPeer{testIngressPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testIngressPeer, ret[0])
})
}
func TestIngress_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestIngress_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testIngressPeer)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testIngressPeer, *ret)
})
}
func TestIngress_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestIngress_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiIngressPeersJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "peer-id", req.PeerId)
retBytes, _ := json.Marshal(testIngressPeer)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Create(context.Background(), api.PostApiIngressPeersJSONRequestBody{
PeerId: "peer-id",
})
require.NoError(t, err)
assert.Equal(t, testIngressPeer, *ret)
})
}
func TestIngress_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Create(context.Background(), api.PostApiIngressPeersJSONRequestBody{
PeerId: "peer-id",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestIngress_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiIngressPeersIngressPeerIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, req.Enabled)
retBytes, _ := json.Marshal(testIngressPeer)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Update(context.Background(), "Test", api.PutApiIngressPeersIngressPeerIdJSONRequestBody{
Enabled: true,
})
require.NoError(t, err)
assert.Equal(t, testIngressPeer, *ret)
})
}
func TestIngress_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Ingress.Update(context.Background(), "Test", api.PutApiIngressPeersIngressPeerIdJSONRequestBody{
Enabled: true,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestIngress_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Ingress.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestIngress_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/ingress/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Ingress.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}

View File

@@ -0,0 +1,46 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// InstanceAPI APIs for Instance status and version, do not use directly
type InstanceAPI struct {
c *Client
}
// GetStatus get instance status
// See more: https://docs.netbird.io/api/resources/instance#get-instance-status
func (a *InstanceAPI) GetStatus(ctx context.Context) (*api.InstanceStatus, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/instance", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.InstanceStatus](resp)
return &ret, err
}
// Setup perform initial instance setup
// See more: https://docs.netbird.io/api/resources/instance#setup-instance
func (a *InstanceAPI) Setup(ctx context.Context, request api.PostApiSetupJSONRequestBody) (*api.SetupResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/setup", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SetupResponse](resp)
return &ret, err
}

View File

@@ -0,0 +1,96 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testInstanceStatus = api.InstanceStatus{
SetupRequired: true,
}
testSetupResponse = api.SetupResponse{
Email: "admin@example.com",
UserId: "user-123",
}
)
func TestInstance_GetStatus_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/instance", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testInstanceStatus)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Instance.GetStatus(context.Background())
require.NoError(t, err)
assert.Equal(t, testInstanceStatus, *ret)
})
}
func TestInstance_GetStatus_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/instance", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Instance.GetStatus(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestInstance_Setup_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiSetupJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "admin@example.com", req.Email)
retBytes, _ := json.Marshal(testSetupResponse)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Instance.Setup(context.Background(), api.PostApiSetupJSONRequestBody{
Email: "admin@example.com",
})
require.NoError(t, err)
assert.Equal(t, testSetupResponse, *ret)
})
}
func TestInstance_Setup_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Instance.Setup(context.Background(), api.PostApiSetupJSONRequestBody{
Email: "admin@example.com",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}

View File

@@ -0,0 +1,122 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// MSPAPI APIs for MSP tenant management
type MSPAPI struct {
c *Client
}
// ListTenants retrieves all MSP tenants
// See more: https://docs.netbird.io/api/resources/msp#list-all-tenants
func (a *MSPAPI) ListTenants(ctx context.Context) (*api.GetTenantsResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/msp/tenants", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.GetTenantsResponse](resp)
return &ret, err
}
// CreateTenant creates a new MSP tenant
// See more: https://docs.netbird.io/api/resources/msp#create-a-tenant
func (a *MSPAPI) CreateTenant(ctx context.Context, request api.CreateTenantRequest) (*api.TenantResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.TenantResponse](resp)
return &ret, err
}
// UpdateTenant updates an existing MSP tenant
// See more: https://docs.netbird.io/api/resources/msp#update-a-tenant
func (a *MSPAPI) UpdateTenant(ctx context.Context, tenantID string, request api.UpdateTenantRequest) (*api.TenantResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/msp/tenants/"+tenantID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.TenantResponse](resp)
return &ret, err
}
// DeleteTenant deletes an MSP tenant
// See more: https://docs.netbird.io/api/resources/msp#delete-a-tenant
func (a *MSPAPI) DeleteTenant(ctx context.Context, tenantID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/msp/tenants/"+tenantID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// UnlinkTenant unlinks a tenant from the MSP account
// See more: https://docs.netbird.io/api/resources/msp#unlink-a-tenant
func (a *MSPAPI) UnlinkTenant(ctx context.Context, tenantID, owner string) error {
params := map[string]string{"owner": owner}
requestBytes, err := json.Marshal(params)
if err != nil {
return err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/unlink", bytes.NewReader(requestBytes), nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// VerifyTenantDNS verifies a tenant domain DNS challenge
// See more: https://docs.netbird.io/api/resources/msp#verify-tenant-dns
func (a *MSPAPI) VerifyTenantDNS(ctx context.Context, tenantID string) error {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/dns", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// InviteTenant invites an existing account as a tenant to the MSP account
// See more: https://docs.netbird.io/api/resources/msp#invite-a-tenant
func (a *MSPAPI) InviteTenant(ctx context.Context, tenantID string) (*api.TenantResponse, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/msp/tenants/"+tenantID+"/invite", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.TenantResponse](resp)
return &ret, err
}

View File

@@ -0,0 +1,251 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testTenant = api.TenantResponse{
Id: "tenant-1",
Name: "Test Tenant",
Domain: "test.example.com",
DnsChallenge: "challenge-123",
Status: "active",
Groups: []api.TenantGroupResponse{},
CreatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
UpdatedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
)
func TestMSP_ListTenants_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.TenantResponse{testTenant})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.ListTenants(context.Background())
require.NoError(t, err)
assert.Len(t, *ret, 1)
assert.Equal(t, testTenant, (*ret)[0])
})
}
func TestMSP_ListTenants_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.ListTenants(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestMSP_CreateTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateTenantRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "Test Tenant", req.Name)
assert.Equal(t, "test.example.com", req.Domain)
retBytes, _ := json.Marshal(testTenant)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.CreateTenant(context.Background(), api.CreateTenantRequest{
Name: "Test Tenant",
Domain: "test.example.com",
Groups: []api.TenantGroupResponse{},
})
require.NoError(t, err)
assert.Equal(t, testTenant, *ret)
})
}
func TestMSP_CreateTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.CreateTenant(context.Background(), api.CreateTenantRequest{
Name: "Test Tenant",
Domain: "test.example.com",
Groups: []api.TenantGroupResponse{},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestMSP_UpdateTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.UpdateTenantRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "Updated Tenant", req.Name)
retBytes, _ := json.Marshal(testTenant)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.UpdateTenant(context.Background(), "tenant-1", api.UpdateTenantRequest{
Name: "Updated Tenant",
Groups: []api.TenantGroupResponse{},
})
require.NoError(t, err)
assert.Equal(t, testTenant, *ret)
})
}
func TestMSP_UpdateTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.UpdateTenant(context.Background(), "tenant-1", api.UpdateTenantRequest{
Name: "Updated Tenant",
Groups: []api.TenantGroupResponse{},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestMSP_DeleteTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.MSP.DeleteTenant(context.Background(), "tenant-1")
require.NoError(t, err)
})
}
func TestMSP_DeleteTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.MSP.DeleteTenant(context.Background(), "tenant-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestMSP_UnlinkTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/unlink", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
w.WriteHeader(200)
})
err := c.MSP.UnlinkTenant(context.Background(), "tenant-1", "owner-1")
require.NoError(t, err)
})
}
func TestMSP_UnlinkTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/unlink", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.MSP.UnlinkTenant(context.Background(), "tenant-1", "owner-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestMSP_VerifyTenantDNS_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/dns", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
w.WriteHeader(200)
})
err := c.MSP.VerifyTenantDNS(context.Background(), "tenant-1")
require.NoError(t, err)
})
}
func TestMSP_VerifyTenantDNS_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/dns", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Failed", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.MSP.VerifyTenantDNS(context.Background(), "tenant-1")
assert.Error(t, err)
assert.Equal(t, "Failed", err.Error())
})
}
func TestMSP_InviteTenant_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/invite", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testTenant)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.InviteTenant(context.Background(), "tenant-1")
require.NoError(t, err)
assert.Equal(t, testTenant, *ret)
})
}
func TestMSP_InviteTenant_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/msp/tenants/tenant-1/invite", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.MSP.InviteTenant(context.Background(), "tenant-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}

View File

@@ -91,6 +91,20 @@ func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
return nil
}
// ListAllRouters list all routers across all networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-routers
func (a *NetworksAPI) ListAllRouters(ctx context.Context) ([]api.NetworkRouter, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/routers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.NetworkRouter](resp)
return ret, err
}
// NetworkResourcesAPI APIs for Network Resources, do not use directly
type NetworkResourcesAPI struct {
c *Client

View File

@@ -219,6 +219,35 @@ func TestNetworks_Integration(t *testing.T) {
})
}
func TestNetworks_ListAllRouters_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.ListAllRouters(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNetworkRouter, ret[0])
})
}
func TestNetworks_ListAllRouters_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.ListAllRouters(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkResources_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {

View File

@@ -106,3 +106,173 @@ func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]ap
ret, err := parseResponse[[]api.Peer](resp)
return ret, err
}
// CreateTemporaryAccess create temporary access for a peer
// See more: https://docs.netbird.io/api/resources/peers#create-temporary-access
func (a *PeersAPI) CreateTemporaryAccess(ctx context.Context, peerID string, request api.PostApiPeersPeerIdTemporaryAccessJSONRequestBody) (*api.PeerTemporaryAccessResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+peerID+"/temporary-access", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PeerTemporaryAccessResponse](resp)
return &ret, err
}
// PeerIngressPortsAPI APIs for Peer Ingress Ports, do not use directly
type PeerIngressPortsAPI struct {
c *Client
peerID string
}
// IngressPorts APIs for peer ingress ports
func (a *PeersAPI) IngressPorts(peerID string) *PeerIngressPortsAPI {
return &PeerIngressPortsAPI{
c: a.c,
peerID: peerID,
}
}
// List list all ingress port allocations for a peer
// See more: https://docs.netbird.io/api/resources/peers#list-all-ingress-port-allocations
func (a *PeerIngressPortsAPI) List(ctx context.Context) ([]api.IngressPortAllocation, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+a.peerID+"/ingress/ports", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IngressPortAllocation](resp)
return ret, err
}
// Get get ingress port allocation info
// See more: https://docs.netbird.io/api/resources/peers#retrieve-an-ingress-port-allocation
func (a *PeerIngressPortsAPI) Get(ctx context.Context, allocationID string) (*api.IngressPortAllocation, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+a.peerID+"/ingress/ports/"+allocationID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPortAllocation](resp)
return &ret, err
}
// Create create new ingress port allocation
// See more: https://docs.netbird.io/api/resources/peers#create-an-ingress-port-allocation
func (a *PeerIngressPortsAPI) Create(ctx context.Context, request api.PostApiPeersPeerIdIngressPortsJSONRequestBody) (*api.IngressPortAllocation, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+a.peerID+"/ingress/ports", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPortAllocation](resp)
return &ret, err
}
// Update update ingress port allocation
// See more: https://docs.netbird.io/api/resources/peers#update-an-ingress-port-allocation
func (a *PeerIngressPortsAPI) Update(ctx context.Context, allocationID string, request api.PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody) (*api.IngressPortAllocation, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+a.peerID+"/ingress/ports/"+allocationID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.IngressPortAllocation](resp)
return &ret, err
}
// Delete delete ingress port allocation
// See more: https://docs.netbird.io/api/resources/peers#delete-an-ingress-port-allocation
func (a *PeerIngressPortsAPI) Delete(ctx context.Context, allocationID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+a.peerID+"/ingress/ports/"+allocationID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// PeerJobsAPI APIs for Peer Jobs, do not use directly
type PeerJobsAPI struct {
c *Client
peerID string
}
// Jobs APIs for peer jobs
func (a *PeersAPI) Jobs(peerID string) *PeerJobsAPI {
return &PeerJobsAPI{
c: a.c,
peerID: peerID,
}
}
// List list all jobs for a peer
// See more: https://docs.netbird.io/api/resources/peers#list-all-peer-jobs
func (a *PeerJobsAPI) List(ctx context.Context) ([]api.JobResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+a.peerID+"/jobs", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.JobResponse](resp)
return ret, err
}
// Get get job info
// See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer-job
func (a *PeerJobsAPI) Get(ctx context.Context, jobID string) (*api.JobResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+a.peerID+"/jobs/"+jobID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.JobResponse](resp)
return &ret, err
}
// Create create new job for a peer
// See more: https://docs.netbird.io/api/resources/peers#create-a-peer-job
func (a *PeerJobsAPI) Create(ctx context.Context, request api.PostApiPeersPeerIdJobsJSONRequestBody) (*api.JobResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/peers/"+a.peerID+"/jobs", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.JobResponse](resp)
return &ret, err
}

View File

@@ -25,6 +25,21 @@ var (
DnsLabel: "test",
Id: "Test",
}
testPeerTemporaryAccess = api.PeerTemporaryAccessResponse{
Id: "Test",
Name: "test-peer",
}
testIngressPortAllocation = api.IngressPortAllocation{
Enabled: true,
Id: "alloc-1",
}
testJobResponse = api.JobResponse{
Id: "job-1",
Status: "pending",
}
)
func TestPeers_List_200(t *testing.T) {
@@ -177,6 +192,264 @@ func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
})
}
func TestPeers_CreateTemporaryAccess_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/temporary-access", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testPeerTemporaryAccess)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.CreateTemporaryAccess(context.Background(), "Test", api.PostApiPeersPeerIdTemporaryAccessJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testPeerTemporaryAccess, *ret)
})
}
func TestPeers_CreateTemporaryAccess_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/temporary-access", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.CreateTemporaryAccess(context.Background(), "Test", api.PostApiPeersPeerIdTemporaryAccessJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeerIngressPorts_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.IngressPortAllocation{testIngressPortAllocation})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testIngressPortAllocation, ret[0])
})
}
func TestPeerIngressPorts_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeerIngressPorts_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testIngressPortAllocation)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Get(context.Background(), "alloc-1")
require.NoError(t, err)
assert.Equal(t, testIngressPortAllocation, *ret)
})
}
func TestPeerIngressPorts_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Get(context.Background(), "alloc-1")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeerIngressPorts_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testIngressPortAllocation)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Create(context.Background(), api.PostApiPeersPeerIdIngressPortsJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testIngressPortAllocation, *ret)
})
}
func TestPeerIngressPorts_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Create(context.Background(), api.PostApiPeersPeerIdIngressPortsJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeerIngressPorts_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
retBytes, _ := json.Marshal(testIngressPortAllocation)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Update(context.Background(), "alloc-1", api.PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testIngressPortAllocation, *ret)
})
}
func TestPeerIngressPorts_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.IngressPorts("Test").Update(context.Background(), "alloc-1", api.PutApiPeersPeerIdIngressPortsAllocationIdJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeerIngressPorts_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Peers.IngressPorts("Test").Delete(context.Background(), "alloc-1")
require.NoError(t, err)
})
}
func TestPeerIngressPorts_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/ingress/ports/alloc-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Peers.IngressPorts("Test").Delete(context.Background(), "alloc-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestPeerJobs_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.JobResponse{testJobResponse})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testJobResponse.Id, ret[0].Id)
assert.Equal(t, testJobResponse.Status, ret[0].Status)
})
}
func TestPeerJobs_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeerJobs_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs/job-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testJobResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").Get(context.Background(), "job-1")
require.NoError(t, err)
assert.Equal(t, testJobResponse.Id, ret.Id)
assert.Equal(t, testJobResponse.Status, ret.Status)
})
}
func TestPeerJobs_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs/job-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").Get(context.Background(), "job-1")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeerJobs_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testJobResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").Create(context.Background(), api.PostApiPeersPeerIdJobsJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testJobResponse.Id, ret.Id)
assert.Equal(t, testJobResponse.Status, ret.Status)
})
}
func TestPeerJobs_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/jobs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Jobs("Test").Create(context.Background(), api.PostApiPeersPeerIdJobsJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
peers, err := c.Peers.List(context.Background())

View File

@@ -0,0 +1,119 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// SCIMAPI APIs for SCIM IDP integrations
type SCIMAPI struct {
c *Client
}
// List retrieves all SCIM IDP integrations
// See more: https://docs.netbird.io/api/resources/scim#list-all-scim-integrations
func (a *SCIMAPI) List(ctx context.Context) ([]api.ScimIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/scim-idp", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.ScimIntegration](resp)
return ret, err
}
// Get retrieves a specific SCIM IDP integration by ID
// See more: https://docs.netbird.io/api/resources/scim#retrieve-a-scim-integration
func (a *SCIMAPI) Get(ctx context.Context, integrationID string) (*api.ScimIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/scim-idp/"+integrationID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimIntegration](resp)
return &ret, err
}
// Create creates a new SCIM IDP integration
// See more: https://docs.netbird.io/api/resources/scim#create-a-scim-integration
func (a *SCIMAPI) Create(ctx context.Context, request api.CreateScimIntegrationRequest) (*api.ScimIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/scim-idp", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimIntegration](resp)
return &ret, err
}
// Update updates an existing SCIM IDP integration
// See more: https://docs.netbird.io/api/resources/scim#update-a-scim-integration
func (a *SCIMAPI) Update(ctx context.Context, integrationID string, request api.UpdateScimIntegrationRequest) (*api.ScimIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/scim-idp/"+integrationID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimIntegration](resp)
return &ret, err
}
// Delete deletes a SCIM IDP integration
// See more: https://docs.netbird.io/api/resources/scim#delete-a-scim-integration
func (a *SCIMAPI) Delete(ctx context.Context, integrationID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/scim-idp/"+integrationID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// RegenerateToken regenerates the SCIM API token for an integration
// See more: https://docs.netbird.io/api/resources/scim#regenerate-scim-token
func (a *SCIMAPI) RegenerateToken(ctx context.Context, integrationID string) (*api.ScimTokenResponse, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/scim-idp/"+integrationID+"/token", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimTokenResponse](resp)
return &ret, err
}
// GetLogs retrieves synchronization logs for an SCIM IDP integration
// See more: https://docs.netbird.io/api/resources/scim#get-scim-sync-logs
func (a *SCIMAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/scim-idp/"+integrationID+"/logs", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp)
return ret, err
}

View File

@@ -0,0 +1,262 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testScimIntegration = api.ScimIntegration{
Id: 1,
AuthToken: "****",
Enabled: true,
GroupPrefixes: []string{"eng-"},
UserGroupPrefixes: []string{"dev-"},
Provider: "okta",
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
testScimToken = api.ScimTokenResponse{
AuthToken: "new-token-123",
}
testSyncLog = api.IdpIntegrationSyncLog{
Id: 1,
Level: "info",
Message: "Sync completed",
Timestamp: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
)
func TestSCIM_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.ScimIntegration{testScimIntegration})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testScimIntegration, ret[0])
})
}
func TestSCIM_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestSCIM_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testScimIntegration)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Get(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, testScimIntegration, *ret)
})
}
func TestSCIM_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Get(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestSCIM_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateScimIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "okta", req.Provider)
assert.Equal(t, "scim-", req.Prefix)
retBytes, _ := json.Marshal(testScimIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Create(context.Background(), api.CreateScimIntegrationRequest{
Provider: "okta",
Prefix: "scim-",
GroupPrefixes: &[]string{"eng-"},
})
require.NoError(t, err)
assert.Equal(t, testScimIntegration, *ret)
})
}
func TestSCIM_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Create(context.Background(), api.CreateScimIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestSCIM_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.UpdateScimIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, *req.Enabled)
retBytes, _ := json.Marshal(testScimIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Update(context.Background(), "int-1", api.UpdateScimIntegrationRequest{
Enabled: ptr(true),
})
require.NoError(t, err)
assert.Equal(t, testScimIntegration, *ret)
})
}
func TestSCIM_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.Update(context.Background(), "int-1", api.UpdateScimIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestSCIM_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.SCIM.Delete(context.Background(), "int-1")
require.NoError(t, err)
})
}
func TestSCIM_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.SCIM.Delete(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestSCIM_RegenerateToken_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testScimToken)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.RegenerateToken(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, testScimToken, *ret)
})
}
func TestSCIM_RegenerateToken_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.RegenerateToken(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestSCIM_GetLogs_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.GetLogs(context.Background(), "int-1")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testSyncLog, ret[0])
})
}
func TestSCIM_GetLogs_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SCIM.GetLogs(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}

View File

@@ -105,3 +105,145 @@ func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
ret, err := parseResponse[api.User](resp)
return &ret, err
}
// ListInvites list all user invites
// See more: https://docs.netbird.io/api/resources/users#list-all-user-invites
func (a *UsersAPI) ListInvites(ctx context.Context) ([]api.UserInvite, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/invites", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.UserInvite](resp)
return ret, err
}
// CreateInvite create a user invite
// See more: https://docs.netbird.io/api/resources/users#create-a-user-invite
func (a *UsersAPI) CreateInvite(ctx context.Context, request api.PostApiUsersInvitesJSONRequestBody) (*api.UserInvite, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/invites", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UserInvite](resp)
return &ret, err
}
// DeleteInvite delete a user invite
// See more: https://docs.netbird.io/api/resources/users#delete-a-user-invite
func (a *UsersAPI) DeleteInvite(ctx context.Context, inviteID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/invites/"+inviteID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// RegenerateInvite regenerate a user invite token
// See more: https://docs.netbird.io/api/resources/users#regenerate-a-user-invite
func (a *UsersAPI) RegenerateInvite(ctx context.Context, inviteID string, request api.PostApiUsersInvitesInviteIdRegenerateJSONRequestBody) (*api.UserInviteRegenerateResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/invites/"+inviteID+"/regenerate", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UserInviteRegenerateResponse](resp)
return &ret, err
}
// GetInviteByToken get a user invite by token
// See more: https://docs.netbird.io/api/resources/users#get-a-user-invite-by-token
func (a *UsersAPI) GetInviteByToken(ctx context.Context, token string) (*api.UserInviteInfo, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/invites/"+token, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UserInviteInfo](resp)
return &ret, err
}
// AcceptInvite accept a user invite
// See more: https://docs.netbird.io/api/resources/users#accept-a-user-invite
func (a *UsersAPI) AcceptInvite(ctx context.Context, token string, request api.PostApiUsersInvitesTokenAcceptJSONRequestBody) (*api.UserInviteAcceptResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/invites/"+token+"/accept", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.UserInviteAcceptResponse](resp)
return &ret, err
}
// Approve approve a pending user
// See more: https://docs.netbird.io/api/resources/users#approve-a-user
func (a *UsersAPI) Approve(ctx context.Context, userID string) (*api.User, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/approve", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.User](resp)
return &ret, err
}
// ChangePassword change a user's password
// See more: https://docs.netbird.io/api/resources/users#change-user-password
func (a *UsersAPI) ChangePassword(ctx context.Context, userID string, request api.PutApiUsersUserIdPasswordJSONRequestBody) error {
requestBytes, err := json.Marshal(request)
if err != nil {
return err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID+"/password", bytes.NewReader(requestBytes), nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// Reject reject a pending user
// See more: https://docs.netbird.io/api/resources/users#reject-a-user
func (a *UsersAPI) Reject(ctx context.Context, userID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/reject", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -32,6 +32,23 @@ var (
Role: "user",
Status: api.UserStatusActive,
}
testUserInvite = api.UserInvite{
AutoGroups: []string{"group1"},
Id: "invite-1",
}
testUserInviteInfo = api.UserInviteInfo{
Email: "invite@test.com",
}
testUserInviteAcceptResponse = api.UserInviteAcceptResponse{
Success: true,
}
testUserInviteRegenerateResponse = api.UserInviteRegenerateResponse{
InviteToken: "new-token",
}
)
func TestUsers_List_200(t *testing.T) {
@@ -220,6 +237,269 @@ func TestUsers_Current_Err(t *testing.T) {
})
}
func TestUsers_ListInvites_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.UserInvite{testUserInvite})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.ListInvites(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testUserInvite, ret[0])
})
}
func TestUsers_ListInvites_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.ListInvites(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestUsers_CreateInvite_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiUsersInvitesJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "invite@test.com", req.Email)
retBytes, _ := json.Marshal(testUserInvite)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.CreateInvite(context.Background(), api.PostApiUsersInvitesJSONRequestBody{
Email: "invite@test.com",
})
require.NoError(t, err)
assert.Equal(t, testUserInvite, *ret)
})
}
func TestUsers_CreateInvite_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.CreateInvite(context.Background(), api.PostApiUsersInvitesJSONRequestBody{
Email: "invite@test.com",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_DeleteInvite_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/invite-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Users.DeleteInvite(context.Background(), "invite-1")
require.NoError(t, err)
})
}
func TestUsers_DeleteInvite_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/invite-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.DeleteInvite(context.Background(), "invite-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestUsers_RegenerateInvite_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/invite-1/regenerate", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testUserInviteRegenerateResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.RegenerateInvite(context.Background(), "invite-1", api.PostApiUsersInvitesInviteIdRegenerateJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testUserInviteRegenerateResponse, *ret)
})
}
func TestUsers_RegenerateInvite_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/invite-1/regenerate", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.RegenerateInvite(context.Background(), "invite-1", api.PostApiUsersInvitesInviteIdRegenerateJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_GetInviteByToken_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/some-token", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testUserInviteInfo)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.GetInviteByToken(context.Background(), "some-token")
require.NoError(t, err)
assert.Equal(t, testUserInviteInfo, *ret)
})
}
func TestUsers_GetInviteByToken_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/some-token", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.GetInviteByToken(context.Background(), "some-token")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestUsers_AcceptInvite_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/some-token/accept", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testUserInviteAcceptResponse)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.AcceptInvite(context.Background(), "some-token", api.PostApiUsersInvitesTokenAcceptJSONRequestBody{})
require.NoError(t, err)
assert.Equal(t, testUserInviteAcceptResponse, *ret)
})
}
func TestUsers_AcceptInvite_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/invites/some-token/accept", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.AcceptInvite(context.Background(), "some-token", api.PostApiUsersInvitesTokenAcceptJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_Approve_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/approve", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testUser)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Approve(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testUser, *ret)
})
}
func TestUsers_Approve_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/approve", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Approve(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_ChangePassword_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/password", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiUsersUserIdPasswordJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
w.WriteHeader(200)
})
err := c.Users.ChangePassword(context.Background(), "Test", api.PutApiUsersUserIdPasswordJSONRequestBody{})
require.NoError(t, err)
})
}
func TestUsers_ChangePassword_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/password", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.ChangePassword(context.Background(), "Test", api.PutApiUsersUserIdPasswordJSONRequestBody{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
})
}
func TestUsers_Reject_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/reject", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Users.Reject(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestUsers_Reject_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/reject", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.Reject(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
})
}
func TestUsers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
// rest client PAT is owner's

File diff suppressed because it is too large Load Diff

View File

@@ -16,6 +16,14 @@ const (
TokenAuthScopes = "TokenAuth.Scopes"
)
// Defines values for CreateIntegrationRequestPlatform.
const (
CreateIntegrationRequestPlatformDatadog CreateIntegrationRequestPlatform = "datadog"
CreateIntegrationRequestPlatformFirehose CreateIntegrationRequestPlatform = "firehose"
CreateIntegrationRequestPlatformGenericHttp CreateIntegrationRequestPlatform = "generic_http"
CreateIntegrationRequestPlatformS3 CreateIntegrationRequestPlatform = "s3"
)
// Defines values for DNSRecordType.
const (
DNSRecordTypeA DNSRecordType = "A"
@@ -188,6 +196,20 @@ const (
IngressPortAllocationRequestPortRangeProtocolUdp IngressPortAllocationRequestPortRangeProtocol = "udp"
)
// Defines values for IntegrationResponsePlatform.
const (
IntegrationResponsePlatformDatadog IntegrationResponsePlatform = "datadog"
IntegrationResponsePlatformFirehose IntegrationResponsePlatform = "firehose"
IntegrationResponsePlatformGenericHttp IntegrationResponsePlatform = "generic_http"
IntegrationResponsePlatformS3 IntegrationResponsePlatform = "s3"
)
// Defines values for InvoiceResponseType.
const (
InvoiceResponseTypeAccount InvoiceResponseType = "account"
InvoiceResponseTypeTenants InvoiceResponseType = "tenants"
)
// Defines values for JobResponseStatus.
const (
JobResponseStatusFailed JobResponseStatus = "failed"
@@ -266,6 +288,21 @@ const (
ResourceTypeSubnet ResourceType = "subnet"
)
// Defines values for SentinelOneMatchAttributesNetworkStatus.
const (
SentinelOneMatchAttributesNetworkStatusConnected SentinelOneMatchAttributesNetworkStatus = "connected"
SentinelOneMatchAttributesNetworkStatusDisconnected SentinelOneMatchAttributesNetworkStatus = "disconnected"
SentinelOneMatchAttributesNetworkStatusQuarantined SentinelOneMatchAttributesNetworkStatus = "quarantined"
)
// Defines values for TenantResponseStatus.
const (
TenantResponseStatusActive TenantResponseStatus = "active"
TenantResponseStatusExisting TenantResponseStatus = "existing"
TenantResponseStatusInvited TenantResponseStatus = "invited"
TenantResponseStatusPending TenantResponseStatus = "pending"
)
// Defines values for UserStatus.
const (
UserStatusActive UserStatus = "active"
@@ -299,6 +336,12 @@ const (
GetApiEventsNetworkTrafficParamsDirectionINGRESS GetApiEventsNetworkTrafficParamsDirection = "INGRESS"
)
// Defines values for PutApiIntegrationsMspTenantsIdInviteJSONBodyValue.
const (
PutApiIntegrationsMspTenantsIdInviteJSONBodyValueAccept PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "accept"
PutApiIntegrationsMspTenantsIdInviteJSONBodyValueDecline PutApiIntegrationsMspTenantsIdInviteJSONBodyValue = "decline"
)
// AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct {
// CityName Commonly used English name of the city
@@ -490,6 +533,21 @@ type BundleWorkloadResponse struct {
Type WorkloadType `json:"type"`
}
// BypassResponse Response for bypassed peer operations.
type BypassResponse struct {
// PeerId The ID of the bypassed peer.
PeerId string `json:"peer_id"`
}
// CheckoutResponse defines model for CheckoutResponse.
type CheckoutResponse struct {
// SessionId The unique identifier for the checkout session.
SessionId string `json:"session_id"`
// Url URL to redirect the user to the checkout session.
Url string `json:"url"`
}
// Checks List of objects that perform the actual checks
type Checks struct {
// GeoLocationCheck Posture check for geo location
@@ -532,6 +590,36 @@ type Country struct {
// CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
type CountryCode = string
// CreateIntegrationRequest Request payload for creating a new event streaming integration. Also used as the structure for the PUT request body, but not all fields are applicable for updates (see PUT operation description).
type CreateIntegrationRequest struct {
// Config Platform-specific configuration as key-value pairs. For creation, all necessary credentials and settings must be provided. For updates, provide the fields to change or the entire new configuration.
Config map[string]string `json:"config"`
// Enabled Specifies whether the integration is enabled. During creation (POST), this value is sent by the client, but the provided backend manager function `CreateIntegration` does not appear to use it directly, so its effect on creation should be verified. During updates (PUT), this field is used to enable or disable the integration.
Enabled bool `json:"enabled"`
// Platform The event streaming platform to integrate with (e.g., "datadog", "s3", "firehose"). This field is used for creation. For updates (PUT), this field, if sent, is ignored by the backend.
Platform CreateIntegrationRequestPlatform `json:"platform"`
}
// CreateIntegrationRequestPlatform The event streaming platform to integrate with (e.g., "datadog", "s3", "firehose"). This field is used for creation. For updates (PUT), this field, if sent, is ignored by the backend.
type CreateIntegrationRequestPlatform string
// CreateScimIntegrationRequest Request payload for creating an SCIM IDP integration
type CreateScimIntegrationRequest struct {
// GroupPrefixes List of start_with string patterns for groups to sync
GroupPrefixes *[]string `json:"group_prefixes,omitempty"`
// Prefix The connection prefix used for the SCIM provider
Prefix string `json:"prefix"`
// Provider Name of the SCIM identity provider
Provider string `json:"provider"`
// UserGroupPrefixes List of start_with string patterns for groups which users to sync
UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"`
}
// CreateSetupKeyRequest defines model for CreateSetupKeyRequest.
type CreateSetupKeyRequest struct {
// AllowExtraDnsLabels Allow extra DNS labels to be added to the peer
@@ -556,6 +644,24 @@ type CreateSetupKeyRequest struct {
UsageLimit int `json:"usage_limit"`
}
// CreateTenantRequest defines model for CreateTenantRequest.
type CreateTenantRequest struct {
// Domain The name for the MSP tenant
Domain string `json:"domain"`
// Groups MSP users Groups that can access the Tenant and Roles to assume
Groups []TenantGroupResponse `json:"groups"`
// Name The name for the MSP tenant
Name string `json:"name"`
}
// DNSChallengeResponse defines model for DNSChallengeResponse.
type DNSChallengeResponse struct {
// DnsChallenge The DNS challenge to set in a TXT record
DnsChallenge string `json:"dns_challenge"`
}
// DNSRecord defines model for DNSRecord.
type DNSRecord struct {
// Content DNS record content (IP address for A/AAAA, domain for CNAME)
@@ -598,6 +704,234 @@ type DNSSettings struct {
DisabledManagementGroups []string `json:"disabled_management_groups"`
}
// EDRFalconRequest Request payload for creating or updating a EDR Falcon integration
type EDRFalconRequest struct {
// ClientId CrowdStrike API client ID
ClientId string `json:"client_id"`
// CloudId CrowdStrike cloud identifier (e.g., "us-1", "us-2", "eu-1")
CloudId string `json:"cloud_id"`
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// Groups The Groups this integration applies to
Groups []string `json:"groups"`
// Secret CrowdStrike API client secret
Secret string `json:"secret"`
// ZtaScoreThreshold The minimum Zero Trust Assessment score required for agent approval (0-100)
ZtaScoreThreshold int `json:"zta_score_threshold"`
}
// EDRFalconResponse Represents a Falcon EDR integration
type EDRFalconResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId string `json:"account_id"`
// CloudId CrowdStrike cloud identifier
CloudId string `json:"cloud_id"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt time.Time `json:"created_at"`
// CreatedBy The user id that created the integration
CreatedBy string `json:"created_by"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// Groups List of groups
Groups []Group `json:"groups"`
// Id The unique numeric identifier for the integration.
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced.
LastSyncedAt time.Time `json:"last_synced_at"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt time.Time `json:"updated_at"`
// ZtaScoreThreshold The minimum Zero Trust Assessment score required for agent approval (0-100)
ZtaScoreThreshold int `json:"zta_score_threshold"`
}
// EDRHuntressRequest Request payload for creating or updating a EDR Huntress integration
type EDRHuntressRequest struct {
// ApiKey Huntress API key
ApiKey string `json:"api_key"`
// ApiSecret Huntress API secret
ApiSecret string `json:"api_secret"`
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// Groups The Groups this integrations applies to
Groups []string `json:"groups"`
// LastSyncedInterval The devices last sync requirement interval in hours. Minimum value is 24 hours
LastSyncedInterval int `json:"last_synced_interval"`
// MatchAttributes Attribute conditions to match when approving agents
MatchAttributes HuntressMatchAttributes `json:"match_attributes"`
}
// EDRHuntressResponse Represents a Huntress EDR integration configuration
type EDRHuntressResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId string `json:"account_id"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt time.Time `json:"created_at"`
// CreatedBy The user id that created the integration
CreatedBy string `json:"created_by"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// Groups List of groups
Groups []Group `json:"groups"`
// Id The unique numeric identifier for the integration.
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced.
LastSyncedAt time.Time `json:"last_synced_at"`
// LastSyncedInterval The devices last sync requirement interval in hours.
LastSyncedInterval int `json:"last_synced_interval"`
// MatchAttributes Attribute conditions to match when approving agents
MatchAttributes HuntressMatchAttributes `json:"match_attributes"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// EDRIntuneRequest Request payload for creating or updating a EDR Intune integration.
type EDRIntuneRequest struct {
// ClientId The Azure application client id
ClientId string `json:"client_id"`
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// Groups The Groups this integrations applies to
Groups []string `json:"groups"`
// LastSyncedInterval The devices last sync requirement interval in hours. Minimum value is 24 hours.
LastSyncedInterval int `json:"last_synced_interval"`
// Secret The Azure application client secret
Secret string `json:"secret"`
// TenantId The Azure tenant id
TenantId string `json:"tenant_id"`
}
// EDRIntuneResponse Represents a Intune EDR integration configuration
type EDRIntuneResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId string `json:"account_id"`
// ClientId The Azure application client id
ClientId string `json:"client_id"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt time.Time `json:"created_at"`
// CreatedBy The user id that created the integration
CreatedBy string `json:"created_by"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// Groups List of groups
Groups []Group `json:"groups"`
// Id The unique numeric identifier for the integration.
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced.
LastSyncedAt time.Time `json:"last_synced_at"`
// LastSyncedInterval The devices last sync requirement interval in hours.
LastSyncedInterval int `json:"last_synced_interval"`
// TenantId The Azure tenant id
TenantId string `json:"tenant_id"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// EDRSentinelOneRequest Request payload for creating or updating a EDR SentinelOne integration
type EDRSentinelOneRequest struct {
// ApiToken SentinelOne API token
ApiToken string `json:"api_token"`
// ApiUrl The Base URL of SentinelOne API
ApiUrl string `json:"api_url"`
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// Groups The Groups this integrations applies to
Groups []string `json:"groups"`
// LastSyncedInterval The devices last sync requirement interval in hours. Minimum value is 24 hours.
LastSyncedInterval int `json:"last_synced_interval"`
// MatchAttributes Attribute conditions to match when approving agents
MatchAttributes SentinelOneMatchAttributes `json:"match_attributes"`
}
// EDRSentinelOneResponse Represents a SentinelOne EDR integration configuration
type EDRSentinelOneResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId string `json:"account_id"`
// ApiUrl The Base URL of SentinelOne API
ApiUrl string `json:"api_url"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt time.Time `json:"created_at"`
// CreatedBy The user id that created the integration
CreatedBy string `json:"created_by"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// Groups List of groups
Groups []Group `json:"groups"`
// Id The unique numeric identifier for the integration.
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced.
LastSyncedAt time.Time `json:"last_synced_at"`
// LastSyncedInterval The devices last sync requirement interval in hours.
LastSyncedInterval int `json:"last_synced_interval"`
// MatchAttributes Attribute conditions to match when approving agents
MatchAttributes SentinelOneMatchAttributes `json:"match_attributes"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// ErrorResponse Standard error response. Note: The exact structure of this error response is inferred from `util.WriteErrorResponse` and `util.WriteError` usage in the provided Go code, as a specific Go struct for errors was not provided.
type ErrorResponse struct {
// Message A human-readable error message.
Message *string `json:"message,omitempty"`
}
// Event defines model for Event.
type Event struct {
// Activity The activity that occurred during the event
@@ -643,6 +977,9 @@ type GeoLocationCheck struct {
// GeoLocationCheckAction Action to take upon policy match
type GeoLocationCheckAction string
// GetTenantsResponse defines model for GetTenantsResponse.
type GetTenantsResponse = []TenantResponse
// Group defines model for Group.
type Group struct {
// Id Group ID
@@ -699,6 +1036,21 @@ type GroupRequest struct {
Resources *[]Resource `json:"resources,omitempty"`
}
// HuntressMatchAttributes Attribute conditions to match when approving agents
type HuntressMatchAttributes struct {
// DefenderPolicyStatus Policy status of Defender AV for Managed Antivirus.
DefenderPolicyStatus *string `json:"defender_policy_status,omitempty"`
// DefenderStatus Status of Defender AV Managed Antivirus.
DefenderStatus *string `json:"defender_status,omitempty"`
// DefenderSubstatus Sub-status of Defender AV Managed Antivirus.
DefenderSubstatus *string `json:"defender_substatus,omitempty"`
// FirewallStatus Status of agent firewall. Can be one of Disabled, Enabled, Pending Isolation, Isolated, Pending Release.
FirewallStatus *string `json:"firewall_status,omitempty"`
}
// IdentityProvider defines model for IdentityProvider.
type IdentityProvider struct {
// ClientId OAuth2 client ID
@@ -738,6 +1090,21 @@ type IdentityProviderRequest struct {
// IdentityProviderType Type of identity provider
type IdentityProviderType string
// IdpIntegrationSyncLog Represents a synchronization log entry for an integration
type IdpIntegrationSyncLog struct {
// Id The unique identifier for the sync log
Id int64 `json:"id"`
// Level The log level
Level string `json:"level"`
// Message Log message
Message string `json:"message"`
// Timestamp Timestamp of when the log was created
Timestamp time.Time `json:"timestamp"`
}
// IngressPeer defines model for IngressPeer.
type IngressPeer struct {
AvailablePorts AvailablePorts `json:"available_ports"`
@@ -892,6 +1259,57 @@ type InstanceVersionInfo struct {
ManagementUpdateAvailable bool `json:"management_update_available"`
}
// IntegrationResponse Represents an event streaming integration.
type IntegrationResponse struct {
// AccountId The identifier of the account this integration belongs to.
AccountId *string `json:"account_id,omitempty"`
// Config Configuration for the integration. Sensitive keys (like API keys, secret keys) are masked with '****' in responses, as indicated by the GetIntegration handler logic.
Config *map[string]string `json:"config,omitempty"`
// CreatedAt Timestamp of when the integration was created.
CreatedAt *time.Time `json:"created_at,omitempty"`
// Enabled Whether the integration is currently active.
Enabled *bool `json:"enabled,omitempty"`
// Id The unique numeric identifier for the integration.
Id *int64 `json:"id,omitempty"`
// Platform The event streaming platform.
Platform *IntegrationResponsePlatform `json:"platform,omitempty"`
// UpdatedAt Timestamp of when the integration was last updated.
UpdatedAt *time.Time `json:"updated_at,omitempty"`
}
// IntegrationResponsePlatform The event streaming platform.
type IntegrationResponsePlatform string
// InvoicePDFResponse defines model for InvoicePDFResponse.
type InvoicePDFResponse struct {
// Url URL to redirect the user to invoice.
Url string `json:"url"`
}
// InvoiceResponse defines model for InvoiceResponse.
type InvoiceResponse struct {
// Id The Stripe invoice id
Id string `json:"id"`
// PeriodEnd The end date of the invoice period.
PeriodEnd time.Time `json:"period_end"`
// PeriodStart The start date of the invoice period.
PeriodStart time.Time `json:"period_start"`
// Type The invoice type
Type InvoiceResponseType `json:"type"`
}
// InvoiceResponseType The invoice type
type InvoiceResponseType string
// JobRequest defines model for JobRequest.
type JobRequest struct {
Workload WorkloadRequest `json:"workload"`
@@ -1797,6 +2215,15 @@ type PolicyUpdate struct {
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
}
// PortalResponse defines model for PortalResponse.
type PortalResponse struct {
// SessionId The unique identifier for the customer portal session.
SessionId string `json:"session_id"`
// Url URL to redirect the user to the customer portal.
Url string `json:"url"`
}
// PostureCheck defines model for PostureCheck.
type PostureCheck struct {
// Checks List of objects that perform the actual checks
@@ -1824,6 +2251,21 @@ type PostureCheckUpdate struct {
Name string `json:"name"`
}
// Price defines model for Price.
type Price struct {
// Currency Currency code for this price.
Currency string `json:"currency"`
// Price Price amount in minor units (e.g., cents).
Price int `json:"price"`
// PriceId Unique identifier for the price.
PriceId string `json:"price_id"`
// Unit Unit of measurement for this price (e.g., per user).
Unit string `json:"unit"`
}
// Process Describes the operational activity within a peer's system.
type Process struct {
// LinuxPath Path to the process executable file in a Linux operating system
@@ -1841,6 +2283,24 @@ type ProcessCheck struct {
Processes []Process `json:"processes"`
}
// Product defines model for Product.
type Product struct {
// Description Detailed description of the product.
Description string `json:"description"`
// Features List of features provided by the product.
Features []string `json:"features"`
// Free Indicates whether the product is free or not.
Free bool `json:"free"`
// Name Name of the product.
Name string `json:"name"`
// Prices List of prices for the product in different currencies
Prices []Price `json:"prices"`
}
// Resource defines model for Resource.
type Resource struct {
// Id ID of the resource
@@ -1950,6 +2410,66 @@ type RulePortRange struct {
Start int `json:"start"`
}
// ScimIntegration Represents a SCIM IDP integration
type ScimIntegration struct {
// AuthToken SCIM API token (full on creation, masked otherwise)
AuthToken string `json:"auth_token"`
// Enabled Indicates whether the integration is enabled
Enabled bool `json:"enabled"`
// GroupPrefixes List of start_with string patterns for groups to sync
GroupPrefixes []string `json:"group_prefixes"`
// Id The unique identifier for the integration
Id int64 `json:"id"`
// LastSyncedAt Timestamp of when the integration was last synced
LastSyncedAt time.Time `json:"last_synced_at"`
// Provider Name of the SCIM identity provider
Provider string `json:"provider"`
// UserGroupPrefixes List of start_with string patterns for groups which users to sync
UserGroupPrefixes []string `json:"user_group_prefixes"`
}
// ScimTokenResponse Response containing the regenerated SCIM token
type ScimTokenResponse struct {
// AuthToken The newly generated SCIM API token
AuthToken string `json:"auth_token"`
}
// SentinelOneMatchAttributes Attribute conditions to match when approving agents
type SentinelOneMatchAttributes struct {
// ActiveThreats The maximum allowed number of active threats on the agent
ActiveThreats *int `json:"active_threats,omitempty"`
// EncryptedApplications Whether disk encryption is enabled on the agent
EncryptedApplications *bool `json:"encrypted_applications,omitempty"`
// FirewallEnabled Whether the agent firewall is enabled
FirewallEnabled *bool `json:"firewall_enabled,omitempty"`
// Infected Whether the agent is currently flagged as infected
Infected *bool `json:"infected,omitempty"`
// IsActive Whether the agent has been recently active and reporting
IsActive *bool `json:"is_active,omitempty"`
// IsUpToDate Whether the agent is running the latest available version
IsUpToDate *bool `json:"is_up_to_date,omitempty"`
// NetworkStatus The current network connectivity status of the device
NetworkStatus *SentinelOneMatchAttributesNetworkStatus `json:"network_status,omitempty"`
// OperationalState The current operational state of the agent
OperationalState *string `json:"operational_state,omitempty"`
}
// SentinelOneMatchAttributesNetworkStatus The current network connectivity status of the device
type SentinelOneMatchAttributesNetworkStatus string
// SetupKey defines model for SetupKey.
type SetupKey struct {
// AllowExtraDnsLabels Allow extra DNS labels to be added to the peer
@@ -2121,6 +2641,117 @@ type SetupResponse struct {
UserId string `json:"user_id"`
}
// Subscription defines model for Subscription.
type Subscription struct {
// Active Indicates whether the subscription is active or not.
Active bool `json:"active"`
// Currency Currency code of the subscription.
Currency string `json:"currency"`
// Features List of features included in the subscription.
Features *[]string `json:"features,omitempty"`
// PlanTier The tier of the plan for the subscription.
PlanTier string `json:"plan_tier"`
// Price Price amount in minor units (e.g., cents).
Price int `json:"price"`
// PriceId Unique identifier for the price of the subscription.
PriceId string `json:"price_id"`
// Provider The provider of the subscription.
Provider string `json:"provider"`
// RemainingTrial The remaining time for the trial period, in seconds.
RemainingTrial *int `json:"remaining_trial,omitempty"`
// UpdatedAt The date and time when the subscription was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// TenantGroupResponse defines model for TenantGroupResponse.
type TenantGroupResponse struct {
// Id The Group ID
Id string `json:"id"`
// Role The Role name
Role string `json:"role"`
}
// TenantResponse defines model for TenantResponse.
type TenantResponse struct {
// ActivatedAt The date and time when the tenant was activated.
ActivatedAt *time.Time `json:"activated_at,omitempty"`
// CreatedAt The date and time when the tenant was created.
CreatedAt time.Time `json:"created_at"`
// DnsChallenge The DNS challenge to set in a TXT record
DnsChallenge string `json:"dns_challenge"`
// Domain The tenant account domain
Domain string `json:"domain"`
// Groups MSP users Groups that can access the Tenant and Roles to assume
Groups []TenantGroupResponse `json:"groups"`
// Id The updated MSP tenant account ID
Id string `json:"id"`
// InvitedAt The date and time when the existing tenant was invited.
InvitedAt *time.Time `json:"invited_at,omitempty"`
// Name The name for the MSP tenant
Name string `json:"name"`
// Status The status of the tenant
Status TenantResponseStatus `json:"status"`
// UpdatedAt The date and time when the tenant was last updated.
UpdatedAt time.Time `json:"updated_at"`
}
// TenantResponseStatus The status of the tenant
type TenantResponseStatus string
// UpdateScimIntegrationRequest Request payload for updating an SCIM IDP integration
type UpdateScimIntegrationRequest struct {
// Enabled Indicates whether the integration is enabled
Enabled *bool `json:"enabled,omitempty"`
// GroupPrefixes List of start_with string patterns for groups to sync
GroupPrefixes *[]string `json:"group_prefixes,omitempty"`
// UserGroupPrefixes List of start_with string patterns for groups which users to sync
UserGroupPrefixes *[]string `json:"user_group_prefixes,omitempty"`
}
// UpdateTenantRequest defines model for UpdateTenantRequest.
type UpdateTenantRequest struct {
// Groups MSP users Groups that can access the Tenant and Roles to assume
Groups []TenantGroupResponse `json:"groups"`
// Name The name for the MSP tenant
Name string `json:"name"`
}
// UsageStats defines model for UsageStats.
type UsageStats struct {
// ActivePeers Number of active peers.
ActivePeers int64 `json:"active_peers"`
// ActiveUsers Number of active users.
ActiveUsers int64 `json:"active_users"`
// TotalPeers Total number of peers.
TotalPeers int64 `json:"total_peers"`
// TotalUsers Total number of users.
TotalUsers int64 `json:"total_users"`
}
// User defines model for User.
type User struct {
// AutoGroups Group IDs to auto-assign to peers registered by this user
@@ -2407,6 +3038,66 @@ type GetApiGroupsParams struct {
Name *string `form:"name,omitempty" json:"name,omitempty"`
}
// PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody defines parameters for PostApiIntegrationsBillingAwsMarketplaceActivate.
type PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody struct {
// PlanTier The plan tier to activate the subscription for.
PlanTier string `json:"plan_tier"`
}
// PostApiIntegrationsBillingAwsMarketplaceEnrichJSONBody defines parameters for PostApiIntegrationsBillingAwsMarketplaceEnrich.
type PostApiIntegrationsBillingAwsMarketplaceEnrichJSONBody struct {
// AwsUserId The AWS user ID.
AwsUserId string `json:"aws_user_id"`
}
// PostApiIntegrationsBillingCheckoutJSONBody defines parameters for PostApiIntegrationsBillingCheckout.
type PostApiIntegrationsBillingCheckoutJSONBody struct {
// BaseURL The base URL for the redirect after checkout.
BaseURL string `json:"baseURL"`
// EnableTrial Enables a 14-day trial for the account.
EnableTrial *bool `json:"enableTrial,omitempty"`
// PriceID The Price ID for checkout.
PriceID string `json:"priceID"`
}
// GetApiIntegrationsBillingPortalParams defines parameters for GetApiIntegrationsBillingPortal.
type GetApiIntegrationsBillingPortalParams struct {
// BaseURL The base URL for the redirect after accessing the portal.
BaseURL string `form:"baseURL" json:"baseURL"`
}
// PutApiIntegrationsBillingSubscriptionJSONBody defines parameters for PutApiIntegrationsBillingSubscription.
type PutApiIntegrationsBillingSubscriptionJSONBody struct {
// PlanTier The plan tier to change the subscription to.
PlanTier *string `json:"plan_tier,omitempty"`
// PriceID The Price ID to change the subscription to.
PriceID *string `json:"priceID,omitempty"`
}
// PutApiIntegrationsMspTenantsIdInviteJSONBody defines parameters for PutApiIntegrationsMspTenantsIdInvite.
type PutApiIntegrationsMspTenantsIdInviteJSONBody struct {
// Value Accept or decline the invitation.
Value PutApiIntegrationsMspTenantsIdInviteJSONBodyValue `json:"value"`
}
// PutApiIntegrationsMspTenantsIdInviteJSONBodyValue defines parameters for PutApiIntegrationsMspTenantsIdInvite.
type PutApiIntegrationsMspTenantsIdInviteJSONBodyValue string
// PostApiIntegrationsMspTenantsIdSubscriptionJSONBody defines parameters for PostApiIntegrationsMspTenantsIdSubscription.
type PostApiIntegrationsMspTenantsIdSubscriptionJSONBody struct {
// PriceID The Price ID to change the subscription to.
PriceID string `json:"priceID"`
}
// PostApiIntegrationsMspTenantsIdUnlinkJSONBody defines parameters for PostApiIntegrationsMspTenantsIdUnlink.
type PostApiIntegrationsMspTenantsIdUnlinkJSONBody struct {
// Owner The new owners user ID.
Owner string `json:"owner"`
}
// GetApiPeersParams defines parameters for GetApiPeers.
type GetApiPeersParams struct {
// Name Filter peers by name
@@ -2452,6 +3143,12 @@ type PostApiDnsZonesZoneIdRecordsJSONRequestBody = DNSRecordRequest
// PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody defines body for PutApiDnsZonesZoneIdRecordsRecordId for application/json ContentType.
type PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody = DNSRecordRequest
// CreateIntegrationJSONRequestBody defines body for CreateIntegration for application/json ContentType.
type CreateIntegrationJSONRequestBody = CreateIntegrationRequest
// UpdateIntegrationJSONRequestBody defines body for UpdateIntegration for application/json ContentType.
type UpdateIntegrationJSONRequestBody = CreateIntegrationRequest
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
type PostApiGroupsJSONRequestBody = GroupRequest
@@ -2470,6 +3167,63 @@ type PostApiIngressPeersJSONRequestBody = IngressPeerCreateRequest
// PutApiIngressPeersIngressPeerIdJSONRequestBody defines body for PutApiIngressPeersIngressPeerId for application/json ContentType.
type PutApiIngressPeersIngressPeerIdJSONRequestBody = IngressPeerUpdateRequest
// PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody defines body for PostApiIntegrationsBillingAwsMarketplaceActivate for application/json ContentType.
type PostApiIntegrationsBillingAwsMarketplaceActivateJSONRequestBody PostApiIntegrationsBillingAwsMarketplaceActivateJSONBody
// PostApiIntegrationsBillingAwsMarketplaceEnrichJSONRequestBody defines body for PostApiIntegrationsBillingAwsMarketplaceEnrich for application/json ContentType.
type PostApiIntegrationsBillingAwsMarketplaceEnrichJSONRequestBody PostApiIntegrationsBillingAwsMarketplaceEnrichJSONBody
// PostApiIntegrationsBillingCheckoutJSONRequestBody defines body for PostApiIntegrationsBillingCheckout for application/json ContentType.
type PostApiIntegrationsBillingCheckoutJSONRequestBody PostApiIntegrationsBillingCheckoutJSONBody
// PutApiIntegrationsBillingSubscriptionJSONRequestBody defines body for PutApiIntegrationsBillingSubscription for application/json ContentType.
type PutApiIntegrationsBillingSubscriptionJSONRequestBody PutApiIntegrationsBillingSubscriptionJSONBody
// CreateFalconEDRIntegrationJSONRequestBody defines body for CreateFalconEDRIntegration for application/json ContentType.
type CreateFalconEDRIntegrationJSONRequestBody = EDRFalconRequest
// UpdateFalconEDRIntegrationJSONRequestBody defines body for UpdateFalconEDRIntegration for application/json ContentType.
type UpdateFalconEDRIntegrationJSONRequestBody = EDRFalconRequest
// CreateHuntressEDRIntegrationJSONRequestBody defines body for CreateHuntressEDRIntegration for application/json ContentType.
type CreateHuntressEDRIntegrationJSONRequestBody = EDRHuntressRequest
// UpdateHuntressEDRIntegrationJSONRequestBody defines body for UpdateHuntressEDRIntegration for application/json ContentType.
type UpdateHuntressEDRIntegrationJSONRequestBody = EDRHuntressRequest
// CreateEDRIntegrationJSONRequestBody defines body for CreateEDRIntegration for application/json ContentType.
type CreateEDRIntegrationJSONRequestBody = EDRIntuneRequest
// UpdateEDRIntegrationJSONRequestBody defines body for UpdateEDRIntegration for application/json ContentType.
type UpdateEDRIntegrationJSONRequestBody = EDRIntuneRequest
// CreateSentinelOneEDRIntegrationJSONRequestBody defines body for CreateSentinelOneEDRIntegration for application/json ContentType.
type CreateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest
// UpdateSentinelOneEDRIntegrationJSONRequestBody defines body for UpdateSentinelOneEDRIntegration for application/json ContentType.
type UpdateSentinelOneEDRIntegrationJSONRequestBody = EDRSentinelOneRequest
// PostApiIntegrationsMspTenantsJSONRequestBody defines body for PostApiIntegrationsMspTenants for application/json ContentType.
type PostApiIntegrationsMspTenantsJSONRequestBody = CreateTenantRequest
// PutApiIntegrationsMspTenantsIdJSONRequestBody defines body for PutApiIntegrationsMspTenantsId for application/json ContentType.
type PutApiIntegrationsMspTenantsIdJSONRequestBody = UpdateTenantRequest
// PutApiIntegrationsMspTenantsIdInviteJSONRequestBody defines body for PutApiIntegrationsMspTenantsIdInvite for application/json ContentType.
type PutApiIntegrationsMspTenantsIdInviteJSONRequestBody PutApiIntegrationsMspTenantsIdInviteJSONBody
// PostApiIntegrationsMspTenantsIdSubscriptionJSONRequestBody defines body for PostApiIntegrationsMspTenantsIdSubscription for application/json ContentType.
type PostApiIntegrationsMspTenantsIdSubscriptionJSONRequestBody PostApiIntegrationsMspTenantsIdSubscriptionJSONBody
// PostApiIntegrationsMspTenantsIdUnlinkJSONRequestBody defines body for PostApiIntegrationsMspTenantsIdUnlink for application/json ContentType.
type PostApiIntegrationsMspTenantsIdUnlinkJSONRequestBody PostApiIntegrationsMspTenantsIdUnlinkJSONBody
// CreateSCIMIntegrationJSONRequestBody defines body for CreateSCIMIntegration for application/json ContentType.
type CreateSCIMIntegrationJSONRequestBody = CreateScimIntegrationRequest
// UpdateSCIMIntegrationJSONRequestBody defines body for UpdateSCIMIntegration for application/json ContentType.
type UpdateSCIMIntegrationJSONRequestBody = UpdateScimIntegrationRequest
// PostApiNetworksJSONRequestBody defines body for PostApiNetworks for application/json ContentType.
type PostApiNetworksJSONRequestBody = NetworkRequest

View File

@@ -40,7 +40,6 @@ func Execute() error {
func init() {
stopCh = make(chan int)
defaultLogFile = "/var/log/netbird/signal.log"
defaultSignalSSLDir = "/var/lib/netbird/"
if runtime.GOOS == "windows" {
defaultLogFile = os.Getenv("PROGRAMDATA") + "\\Netbird\\" + "signal.log"

View File

@@ -18,7 +18,7 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/shared/metrics"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -38,13 +38,13 @@ import (
const legacyGRPCPort = 10000
var (
signalPort int
metricsPort int
signalLetsencryptDomain string
signalSSLDir string
defaultSignalSSLDir string
signalCertFile string
signalCertKey string
signalPort int
metricsPort int
signalLetsencryptDomain string
signalLetsencryptEmail string
signalLetsencryptDataDir string
signalCertFile string
signalCertKey string
signalKaep = grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
MinTime: 5 * time.Second,
@@ -216,7 +216,7 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config
}
if signalLetsencryptDomain != "" {
certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain)
certManager, err = encryption.CreateCertManager(signalLetsencryptDataDir, signalLetsencryptDomain)
if err != nil {
return nil, certManager, nil, err
}
@@ -326,9 +326,11 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.PersistentFlags().StringVar(&signalLetsencryptDataDir, "letsencrypt-data-dir", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
runCmd.PersistentFlags().StringVar(&signalLetsencryptDataDir, "ssl-dir", "", "server ssl directory location. *Required only for Let's Encrypt certificates. Deprecated: use --letsencrypt-data-dir")
runCmd.PersistentFlags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.PersistentFlags().StringVar(&signalLetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
runCmd.PersistentFlags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.PersistentFlags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
setFlagsFromEnvVars(runCmd)
}

View File

@@ -24,15 +24,19 @@ type AppMetrics struct {
MessageSize metric.Int64Histogram
}
func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
activePeers, err := meter.Int64UpDownCounter("active_peers",
func NewAppMetrics(meter metric.Meter, prefix ...string) (*AppMetrics, error) {
p := ""
if len(prefix) > 0 {
p = prefix[0]
}
activePeers, err := meter.Int64UpDownCounter(p+"active_peers",
metric.WithDescription("Number of active connected peers"),
)
if err != nil {
return nil, err
}
peerConnectionDuration, err := meter.Int64Histogram("peer_connection_duration_seconds",
peerConnectionDuration, err := meter.Int64Histogram(p+"peer_connection_duration_seconds",
metric.WithExplicitBucketBoundaries(getPeerConnectionDurationBucketBoundaries()...),
metric.WithDescription("Duration of how long a peer was connected"),
)
@@ -40,28 +44,28 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
return nil, err
}
registrations, err := meter.Int64Counter("registrations_total",
registrations, err := meter.Int64Counter(p+"registrations_total",
metric.WithDescription("Total number of peer registrations"),
)
if err != nil {
return nil, err
}
deregistrations, err := meter.Int64Counter("deregistrations_total",
deregistrations, err := meter.Int64Counter(p+"deregistrations_total",
metric.WithDescription("Total number of peer deregistrations"),
)
if err != nil {
return nil, err
}
registrationFailures, err := meter.Int64Counter("registration_failures_total",
registrationFailures, err := meter.Int64Counter(p+"registration_failures_total",
metric.WithDescription("Total number of peer registration failures"),
)
if err != nil {
return nil, err
}
registrationDelay, err := meter.Float64Histogram("registration_delay_milliseconds",
registrationDelay, err := meter.Float64Histogram(p+"registration_delay_milliseconds",
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...),
metric.WithDescription("Duration of how long it takes to register a peer"),
)
@@ -69,7 +73,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
return nil, err
}
getRegistrationDelay, err := meter.Float64Histogram("get_registration_delay_milliseconds",
getRegistrationDelay, err := meter.Float64Histogram(p+"get_registration_delay_milliseconds",
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...),
metric.WithDescription("Duration of how long it takes to load a connection from the registry"),
)
@@ -77,21 +81,21 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
return nil, err
}
messagesForwarded, err := meter.Int64Counter("messages_forwarded_total",
messagesForwarded, err := meter.Int64Counter(p+"messages_forwarded_total",
metric.WithDescription("Total number of messages forwarded to peers"),
)
if err != nil {
return nil, err
}
messageForwardFailures, err := meter.Int64Counter("message_forward_failures_total",
messageForwardFailures, err := meter.Int64Counter(p+"message_forward_failures_total",
metric.WithDescription("Total number of message forwarding failures"),
)
if err != nil {
return nil, err
}
messageForwardLatency, err := meter.Float64Histogram("message_forward_latency_milliseconds",
messageForwardLatency, err := meter.Float64Histogram(p+"message_forward_latency_milliseconds",
metric.WithExplicitBucketBoundaries(getStandardBucketBoundaries()...),
metric.WithDescription("Duration of how long it takes to forward a message to a peer"),
)
@@ -100,7 +104,7 @@ func NewAppMetrics(meter metric.Meter) (*AppMetrics, error) {
}
messageSize, err := meter.Int64Histogram(
"message.size.bytes",
p+"message.size.bytes",
metric.WithUnit("bytes"),
metric.WithExplicitBucketBoundaries(getMessageSizeBucketBoundaries()...),
metric.WithDescription("Records the size of each message sent"),

View File

@@ -62,8 +62,8 @@ type Server struct {
}
// NewServer creates a new Signal server
func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter)
func NewServer(ctx context.Context, meter metric.Meter, metricsPrefix ...string) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter, metricsPrefix...)
if err != nil {
return nil, fmt.Errorf("creating app metrics: %v", err)
}

View File

@@ -48,7 +48,7 @@ func NewServer(conns []*net.UDPConn, logLevel string) *Server {
// Use the formatter package to set up formatter, ReportCaller, and context hook
formatter.SetTextFormatter(stunLogger)
logger := stunLogger.WithField("component", "stun-server")
logger := stunLogger.WithField("component", "stun")
logger.Infof("STUN server log level set to: %s", level.String())
return &Server{