Compare commits

...

14 Commits

Author SHA1 Message Date
mlsmaycon
279e96e6b1 add disk encryption check 2026-01-17 19:56:50 +01:00
Maycon Santos
245481f33b [client] fix: client/Dockerfile to reduce vulnerabilities (#5119)
The following vulnerabilities are fixed with an upgrade:
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091698
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701
- https://snyk.io/vuln/SNYK-ALPINE322-BUSYBOX-14091701

Co-authored-by: snyk-bot <snyk-bot@snyk.io>
2026-01-16 18:05:41 +01:00
shuuri-labs
b352ab84c0 Feat/quickstart reverse proxy assistant (#5100)
* add external reverse proxy config steps to quickstart script

* remove generated files

* - Remove 'press enter' prompt from post-traefik config since traefik requires no manual config
- Improve npm flow (ask users for docker network, user container names in config)

* fixes for npm flow

* nginx flow fixes

* caddy flow fixes

* Consolidate NPM_NETWORK, NGINX_NETWORK, CADDY_NETWORK into single
EXTERNAL_PROXY_NETWORK variable. Add read_proxy_docker_network()
function that prompts for Docker network for options 2-4 (Nginx,
NPM, Caddy). Generated configs now use container names when a
Docker network is specified.

* fix https for traefik

* fix sonar code smells

* fix sonar smell (add return to render_dashboard_env)

* added tls instructions to nginx flow

* removed unused bind_addr variable from quickstart.sh

* Refactor getting-started.sh for improved maintainability

Break down large functions into focused, single-responsibility components:
- Split init_environment() into 6 initialization functions
- Split print_post_setup_instructions() into 6 proxy-specific functions
- Add section headers for better code organization
- Fix 3 code smell issues (unused bind_addr variables)
- Add TLS certificate documentation for Nginx
- Link reverse proxy names to docs sections

Reduces largest function from 205 to ~90 lines while maintaining
single-file distribution. No functional changes.

* - Remove duplicate network display logic in Traefik instructions
- Use upstream_host instead of bind_addr for NPM forward hostname
- Use upstream_host instead of bind_addr in manual proxy route examples
- Prevents displaying invalid 0.0.0.0 as connection target in setup instructions

* add wait_management_direct to caddy flow to ensure script waits until containers are running/passing healthchecks before reporting 'done!'
2026-01-16 17:42:28 +01:00
ressys1978
3ce5d6a4f8 [management] Add idp timeout env variable (#4647)
Introduced the NETBIRD_IDP_TIMEOUT environment variable to the management service. This allows configuring a timeout for supported IDPs. If the variable is unset or contains an invalid value, a default timeout of 10 seconds is used as a fallback.

This is needed for larger IDP environments where 10s is just not enough time.
2026-01-16 16:23:37 +01:00
Misha Bragin
4c2eb2af73 [management] Skip email_verified if not present (#5118) 2026-01-16 16:01:39 +01:00
Misha Bragin
daf1449174 [client] Remove duplicate audiences check (#5117) 2026-01-16 14:25:02 +02:00
Misha Bragin
1ff7abe909 [management, client] Fix SSH server audience validator (#5105)
* **New Features**
  * SSH server JWT validation now accepts multiple audiences with backward-compatible handling of the previous single-audience setting and a guard ensuring at least one audience is configured.
* **Tests**
  * Test suites updated and new tests added to cover multiple-audience scenarios and compatibility with existing behavior.
* **Other**
  * Startup logging enhanced to report configured audiences for JWT auth.
2026-01-16 12:28:17 +01:00
Bethuel Mmbaga
067c77e49e [management] Add custom dns zones (#4849) 2026-01-16 12:12:05 +03:00
Maycon Santos
291e640b28 [client] Change priority between local and dns route handlers (#5106)
* Change priority between local and dns route handlers

* update priority tests
2026-01-15 17:30:10 +01:00
Pascal Fischer
efb954b7d6 [management] adapt ratelimiting (#5080) 2026-01-15 16:39:14 +01:00
Vlad
cac9326d3d [management] fetch all users data from external cache in one request (#5104)
---------

Co-authored-by: pascal <pascal@netbird.io>
2026-01-14 17:09:17 +01:00
Viktor Liu
520d9c66cf [client] Fix netstack upstream dns and add wasm debug methods (#4648) 2026-01-14 13:56:16 +01:00
Misha Bragin
ff10498a8b Feature/embedded STUN (#5062) 2026-01-14 13:13:30 +01:00
Zoltan Papp
00b747ad5d Handle fallback for invalid loginuid in ui-post-install.sh. (#5099) 2026-01-14 09:53:14 +01:00
111 changed files with 9558 additions and 1194 deletions

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.2
FROM alpine:3.23.2
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \

View File

@@ -314,9 +314,8 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
statusOutputString = overview.FullDetailSummary()
}
return statusOutputString
}

View File

@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
var statusOutputString string
switch {
case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
statusOutputString = outputInformationHolder.FullDetailSummary()
case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
statusOutputString, err = outputInformationHolder.JSON()
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
statusOutputString, err = outputInformationHolder.YAML()
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
}
if err != nil {

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
var (
@@ -38,6 +39,7 @@ type Client struct {
setupKey string
jwtToken string
connect *internal.ConnectClient
recorder *peer.Status
}
// Options configures a new Client.
@@ -161,11 +163,17 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
if c.connect != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
defer func() {
if c.connect == nil {
cancel()
}
}()
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
@@ -173,7 +181,9 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
@@ -197,6 +207,7 @@ func (c *Client) Start(startCtx context.Context) error {
}
c.connect = client
c.cancel = cancel
return nil
}
@@ -211,17 +222,23 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted
}
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
done := make(chan error, 1)
connect := c.connect
go func() {
done <- c.connect.Stop()
done <- connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
c.connect = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
c.connect = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
@@ -315,6 +332,62 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}
return syncResp, nil
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
c.mu.Lock()
connect := c.connect
c.mu.Unlock()
if connect != nil {
connect.SetLogLevel(level)
}
return nil
}
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.

View File

@@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
return syncResponse, nil
}
// SetLogLevel sets the log level for the firewall manager if the engine is running.
func (c *ConnectClient) SetLogLevel(level log.Level) {
engine := c.Engine()
if engine == nil {
return
}
fwManager := engine.GetFirewallManager()
if fwManager != nil {
fwManager.SetLogLevel(level)
}
}
// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {

View File

@@ -16,8 +16,8 @@ import (
const (
PriorityMgmtCache = 150
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityDNSRoute = 100
PriorityLocal = 75
PriorityUpstream = 50
PriorityDefault = 1
PriorityFallback = -100

View File

@@ -631,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,
@@ -743,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
domainGroup.domain,
@@ -926,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
return configurer.WGStats{}, nil
}
func (w *mocWGIface) GetNet() *netstack.Net {
return nil
}
var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
@@ -2047,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")

View File

@@ -18,6 +18,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
@@ -418,6 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
return nil, err
}
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}
return reply, nil
}
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
conn, err := nsNet.DialContext(ctx, network, upstream)
if err != nil {
return nil, fmt.Errorf("with %s: %w", network, err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close DNS connection: %v", err)
}
}()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
}
dnsConn := &dns.Conn{Conn: conn}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)
}
reply, err := dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("read %s message: %w", network, err)
}
return reply, nil
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -23,9 +23,7 @@ type upstreamResolver struct {
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,

View File

@@ -5,22 +5,23 @@ package dns
import (
"context"
"net/netip"
"runtime"
"time"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer"
)
type upstreamResolver struct {
*upstreamResolverBase
nsNet *netstack.Net
}
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -28,12 +29,23 @@ func newUpstreamResolver(
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
// Similar to iOS logic, we should determine if the DNS server is reachable directly
// or needs to go through the tunnel, and only use netstack when necessary.
// For now, only use netstack on JS platform where direct access is not possible.
if u.nsNet != nil && runtime.GOOS == "js" {
start := time.Now()
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
return reply, time.Since(start), err
}
client := &dns.Client{
Timeout: ClientTimeout,
}

View File

@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {
func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip netip.Addr,
net netip.Prefix,
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -37,9 +35,9 @@ func newUpstreamResolver(
ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
interfaceName: interfaceName,
lIP: wgIface.Address().IP,
lNet: wgIface.Address().Network,
interfaceName: wgIface.Name(),
}
ios.upstreamClient = ios

View File

@@ -2,13 +2,17 @@ package dns
import (
"context"
"net"
"net/netip"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/test"
)
@@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
// Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
@@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
}
}
type mockNetstackProvider struct{}
func (m *mockNetstackProvider) Name() string { return "mock" }
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration

View File

@@ -5,6 +5,8 @@ package dns
import (
"net"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -17,4 +19,5 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

View File

@@ -1,6 +1,8 @@
package dns
import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -12,5 +14,6 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
GetInterfaceGUIDString() (string, error)
}

View File

@@ -1748,6 +1748,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
}
e.syncMsgMux.Unlock()
// Skip STUN/TURN probing for JS/WASM as it's not available
relayHealthy := true
if runtime.GOOS != "js" {
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
@@ -1756,7 +1760,6 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
}
e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true
for _, res := range results {
if res.Err != nil {
relayHealthy = false
@@ -1764,6 +1767,7 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
}
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
}
allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy)

View File

@@ -72,9 +72,16 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
audiences := protoJWT.GetAudiences()
if len(audiences) == 0 && protoJWT.GetAudience() != "" {
audiences = []string{protoJWT.GetAudience()}
}
log.Debugf("starting SSH server with JWT authentication: audiences=%v", audiences)
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
Audiences: audiences,
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}

View File

@@ -14,6 +14,7 @@ import (
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -158,6 +159,7 @@ type FullStatus struct {
NSGroupStates []NSGroupState
NumOfForwardingRules int
LazyConnectionEnabled bool
Events []*proto.SystemEvent
}
type StatusChangeSubscription struct {
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
}
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
fullStatus.Events = d.GetEventHistory()
return fullStatus
}
@@ -1181,3 +1184,97 @@ type EventSubscription struct {
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
return s.events
}
// ToProto converts FullStatus to proto.FullStatus.
func (fs FullStatus) ToProto() *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
if err := fs.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fs.SignalState.URL
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
if err := fs.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
for _, peerState := range fs.Peers {
networks := maps.Keys(peerState.GetRoutes())
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: networks,
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fs.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fs.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
pbFullStatus.Events = fs.Events
return &pbFullStatus
}

View File

@@ -17,13 +17,13 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -38,11 +38,6 @@ type internalDNATer interface {
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
type wgInterface interface {
Name() string
Address() wgaddr.Address
}
type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
@@ -52,7 +47,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
wgInterface wgInterface
wgInterface iface.WGIface
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
@@ -250,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
}
if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
@@ -264,20 +253,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
if reply == nil {
return
}
@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
return
}
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
// Returns the DNS reply on success, or nil on error (error responses are written internally).
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
startTime := time.Now()
nsNet := d.wgInterface.GetNet()
var reply *dns.Msg
var err error
if nsNet != nil {
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
} else {
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if clientErr != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
return nil
}
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
}
if err == nil {
return reply
}
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return nil
}
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""

View File

@@ -4,6 +4,8 @@ import (
"net"
"net/netip"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

View File

@@ -173,20 +173,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
log.SetLevel(level)
if s.connectClient == nil {
return nil, fmt.Errorf("connect client not initialized")
if s.connectClient != nil {
s.connectClient.SetLogLevel(level)
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
}
fwManager.SetLogLevel(level)
log.Infof("Log level set to %s", level.String())

View File

@@ -1,8 +1,6 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto"
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
}
}
}
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
events := s.statusRecorder.GetEventHistory()
return &proto.GetEventsResponse{Events: events}, nil
}

View File

@@ -13,15 +13,12 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -1067,11 +1064,9 @@ func (s *Server) Status(
if msg.GetFullPeerStatus {
s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus
}
@@ -1600,94 +1595,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
return defaultDuration
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {

View File

@@ -132,7 +132,7 @@ func TestSSHProxy_Connect(t *testing.T) {
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}

View File

@@ -43,7 +43,7 @@ func TestJWTEnforcement(t *testing.T) {
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
Audiences: []string{"test-audience"},
KeysLocation: "test-keys",
}
serverConfig := &Config{
@@ -202,7 +202,7 @@ func TestJWTDetection(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
@@ -329,7 +329,7 @@ func TestJWTFailClose(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
@@ -567,7 +567,7 @@ func TestJWTAuthentication(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: []string{audience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
@@ -646,3 +646,108 @@ func TestJWTAuthentication(t *testing.T) {
})
}
}
// TestJWTMultipleAudiences tests JWT validation with multiple audiences (dashboard and CLI).
func TestJWTMultipleAudiences(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT multiple audiences tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
dashboardAudience = "dashboard-audience"
cliAudience = "cli-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
audience string
wantAuthOK bool
}{
{
name: "accepts_dashboard_audience",
audience: dashboardAudience,
wantAuthOK: true,
},
{
name: "accepts_cli_audience",
audience: cliAudience,
wantAuthOK: true,
},
{
name: "rejects_unknown_audience",
audience: "unknown-audience",
wantAuthOK: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audiences: []string{dashboardAudience, cliAudience},
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
testUserHash, err := sshuserhash.HashUserID("test-user")
require.NoError(t, err)
currentUser := testutil.GetTestUsername(t)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
currentUser: {0},
},
}
server.UpdateSSHAuth(authConfig)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := generateValidJWT(t, privateKey, issuer, tc.audience)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(token),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed for audience %s", tc.audience)
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
err = session.Shell()
require.NoError(t, err, "Shell should work with valid audience")
} else {
assert.Error(t, err, "JWT authentication should fail for unknown audience")
}
})
}
}

View File

@@ -176,9 +176,9 @@ type Server struct {
type JWTConfig struct {
Issuer string
Audience string
KeysLocation string
MaxTokenAge int64
Audiences []string
}
// Config contains all SSH server configuration options
@@ -427,18 +427,21 @@ func (s *Server) ensureJWTValidator() error {
return fmt.Errorf("JWT config not set")
}
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
if len(config.Audiences) == 0 {
return fmt.Errorf("JWT config has no audiences configured")
}
log.Debugf("Initializing JWT validator (issuer: %s, audiences: %v)", config.Issuer, config.Audiences)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.Audiences,
config.KeysLocation,
true,
)
// Use custom userIDClaim from authorizer if available
extractorOptions := []jwt.ClaimsExtractorOption{
jwt.WithAudience(config.Audience),
jwt.WithAudience(config.Audiences[0]),
}
if authorizer.GetUserIDClaim() != "" {
extractorOptions = append(extractorOptions, jwt.WithUserIDClaim(authorizer.GetUserIDClaim()))
@@ -475,8 +478,8 @@ func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
return nil, fmt.Errorf("validate token (expected issuer=%s, audiences=%v, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audiences, claims["iss"], claims["aud"], err)
}
}
return nil, fmt.Errorf("validate token: %w", err)

View File

@@ -325,61 +325,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
}
}
func ParseToJSON(overview OutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
// JSON returns the status overview as a JSON string.
func (o *OutputOverview) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
func ParseToYAML(overview OutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
// YAML returns the status overview as a YAML string.
func (o *OutputOverview) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
// GeneralSummary returns a general summary of the status overview.
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
var managementConnString string
if overview.ManagementState.Connected {
if o.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
if o.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
}
}
var signalConnString string
if overview.SignalState.Connected {
if o.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
if o.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
}
}
interfaceTypeString := "Userspace"
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceIP := o.IP
if o.KernelInterface {
interfaceTypeString = "Kernel"
} else if overview.IP == "" {
} else if o.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}
var relaysString string
if showRelays {
for _, relay := range overview.Relays.Details {
for _, relay := range o.Relays.Details {
available := "Available"
reason := ""
@@ -395,18 +398,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
}
networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
if len(o.Networks) > 0 {
sort.Strings(o.Networks)
networks = strings.Join(o.Networks, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
for _, nsServerGroup := range o.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
@@ -430,25 +433,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
if o.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
if o.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}
lazyConnectionEnabledStatus := "false"
if overview.LazyConnectionEnabled {
if o.LazyConnectionEnabled {
lazyConnectionEnabledStatus = "true"
}
sshServerStatus := "Disabled"
if overview.SSHServerState.Enabled {
sessionCount := len(overview.SSHServerState.Sessions)
if o.SSHServerState.Enabled {
sessionCount := len(o.SSHServerState.Sessions)
if sessionCount > 0 {
sessionWord := "session"
if sessionCount > 1 {
@@ -460,7 +463,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
}
if showSSHSessions && sessionCount > 0 {
for _, session := range overview.SSHServerState.Sessions {
for _, session := range o.SSHServerState.Sessions {
var sessionDisplay string
if session.JWTUsername != "" {
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
@@ -484,7 +487,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
@@ -512,30 +515,31 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"Forwarding rules: %d\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
o.DaemonVersion,
version.NetbirdVersion(),
overview.ProfileName,
o.ProfileName,
managementConnString,
signalConnString,
relaysString,
dnsServersString,
domain.Domain(overview.FQDN).SafeString(),
domain.Domain(o.FQDN).SafeString(),
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
networks,
overview.NumberOfForwardingRules,
o.NumberOfForwardingRules,
peersCountString,
)
return summary
}
func ParseToFullDetailSummary(overview OutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
parsedEventsString := parseEvents(overview.Events)
summary := ParseGeneralSummary(overview, true, true, true, true)
// FullDetailSummary returns a full detailed summary with peer details and events.
func (o *OutputOverview) FullDetailSummary() string {
parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
parsedEventsString := parseEvents(o.Events)
summary := o.GeneralSummary(true, true, true, true)
return fmt.Sprintf(
"Peers detail:"+

View File

@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
}
func TestParsingToJSON(t *testing.T) {
jsonString, _ := ParseToJSON(overview)
jsonString, _ := overview.JSON()
//@formatter:off
expectedJSONString := `
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
}
func TestParsingToYAML(t *testing.T) {
yaml, _ := ParseToYAML(overview)
yaml, _ := overview.YAML()
expectedYAML :=
`peers:
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := ParseToFullDetailSummary(overview)
detail := overview.FullDetailSummary()
expectedDetail := fmt.Sprintf(
`Peers detail:
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
shortVersion := overview.GeneralSummary(false, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1

View File

@@ -0,0 +1,22 @@
package system
// DiskEncryptionVolume represents encryption status of a single volume.
type DiskEncryptionVolume struct {
Path string
Encrypted bool
}
// DiskEncryptionInfo holds disk encryption detection results.
type DiskEncryptionInfo struct {
Volumes []DiskEncryptionVolume
}
// IsEncrypted returns true if the volume at the given path is encrypted.
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
for _, v := range d.Volumes {
if v.Path == path {
return v.Encrypted
}
}
return false
}

View File

@@ -0,0 +1,35 @@
//go:build darwin && !ios
package system
import (
"context"
"os/exec"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// detectDiskEncryption detects FileVault encryption status on macOS.
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
info := DiskEncryptionInfo{}
cmdCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
cmd := exec.CommandContext(cmdCtx, "fdesetup", "status")
output, err := cmd.Output()
if err != nil {
log.Debugf("execute fdesetup: %v", err)
return info
}
encrypted := strings.Contains(string(output), "FileVault is On")
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
Path: "/",
Encrypted: encrypted,
})
return info
}

View File

@@ -0,0 +1,98 @@
//go:build linux && !android
package system
import (
"bufio"
"context"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
// detectDiskEncryption detects LUKS encryption status on Linux by reading sysfs.
func detectDiskEncryption(ctx context.Context) DiskEncryptionInfo {
info := DiskEncryptionInfo{}
encryptedDevices := findEncryptedDevices()
mountPoints := parseMounts(encryptedDevices)
info.Volumes = mountPoints
return info
}
// findEncryptedDevices scans /sys/block for dm-crypt (LUKS) encrypted devices.
func findEncryptedDevices() map[string]bool {
encryptedDevices := make(map[string]bool)
sysBlock := "/sys/block"
entries, err := os.ReadDir(sysBlock)
if err != nil {
log.Debugf("read /sys/block: %v", err)
return encryptedDevices
}
for _, entry := range entries {
dmUuidPath := filepath.Join(sysBlock, entry.Name(), "dm", "uuid")
data, err := os.ReadFile(dmUuidPath)
if err != nil {
continue
}
uuid := strings.TrimSpace(string(data))
if strings.HasPrefix(uuid, "CRYPT-") {
dmNamePath := filepath.Join(sysBlock, entry.Name(), "dm", "name")
if nameData, err := os.ReadFile(dmNamePath); err == nil {
dmName := strings.TrimSpace(string(nameData))
encryptedDevices["/dev/mapper/"+dmName] = true
}
encryptedDevices["/dev/"+entry.Name()] = true
}
}
return encryptedDevices
}
// parseMounts reads /proc/mounts and maps devices to mount points with encryption status.
func parseMounts(encryptedDevices map[string]bool) []DiskEncryptionVolume {
var volumes []DiskEncryptionVolume
mountsFile, err := os.Open("/proc/mounts")
if err != nil {
log.Debugf("open /proc/mounts: %v", err)
return volumes
}
defer func() {
if err := mountsFile.Close(); err != nil {
log.Debugf("close /proc/mounts: %v", err)
}
}()
scanner := bufio.NewScanner(mountsFile)
for scanner.Scan() {
fields := strings.Fields(scanner.Text())
if len(fields) < 2 {
continue
}
device, mountPoint := fields[0], fields[1]
encrypted := encryptedDevices[device]
if !encrypted && strings.HasPrefix(device, "/dev/mapper/") {
for encDev := range encryptedDevices {
if device == encDev {
encrypted = true
break
}
}
}
volumes = append(volumes, DiskEncryptionVolume{
Path: mountPoint,
Encrypted: encrypted,
})
}
return volumes
}

View File

@@ -0,0 +1,10 @@
//go:build android || ios || freebsd || js
package system
import "context"
// detectDiskEncryption is a stub for unsupported platforms.
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
return DiskEncryptionInfo{}
}

View File

@@ -0,0 +1,41 @@
//go:build windows
package system
import (
"context"
"strings"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
)
// Win32EncryptableVolume represents the WMI class for BitLocker status.
type Win32EncryptableVolume struct {
DriveLetter string
ProtectionStatus uint32
}
// detectDiskEncryption detects BitLocker encryption status on Windows via WMI.
func detectDiskEncryption(_ context.Context) DiskEncryptionInfo {
info := DiskEncryptionInfo{}
var volumes []Win32EncryptableVolume
query := "SELECT DriveLetter, ProtectionStatus FROM Win32_EncryptableVolume"
err := wmi.QueryNamespace(query, &volumes, `root\CIMV2\Security\MicrosoftVolumeEncryption`)
if err != nil {
log.Debugf("query BitLocker status: %v", err)
return info
}
for _, vol := range volumes {
driveLetter := strings.TrimSuffix(vol.DriveLetter, "\\")
info.Volumes = append(info.Volumes, DiskEncryptionVolume{
Path: driveLetter,
Encrypted: vol.ProtectionStatus == 1,
})
}
return info
}

View File

@@ -59,6 +59,7 @@ type Info struct {
SystemManufacturer string
Environment Environment
Files []File // for posture checks
DiskEncryption DiskEncryptionInfo
RosenpassEnabled bool
RosenpassPermissive bool

View File

@@ -44,6 +44,7 @@ func GetInfo(ctx context.Context) *Info {
SystemSerialNumber: serial(),
SystemProductName: productModel(),
SystemManufacturer: productManufacturer(),
DiskEncryption: detectDiskEncryption(ctx),
}
return gio

View File

@@ -62,6 +62,7 @@ func GetInfo(ctx context.Context) *Info {
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
DiskEncryption: detectDiskEncryption(ctx),
}
systemHostname, _ := os.Hostname()

View File

@@ -55,6 +55,7 @@ func GetInfo(ctx context.Context) *Info {
UIVersion: extractUserAgent(ctx),
KernelVersion: osInfo[1],
Environment: env,
DiskEncryption: detectDiskEncryption(ctx),
}
}

View File

@@ -19,7 +19,7 @@ func GetInfo(ctx context.Context) *Info {
sysName := extractOsName(ctx, "sysName")
swVersion := extractOsVersion(ctx, "swVersion")
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion}
gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion, DiskEncryption: detectDiskEncryption(ctx)}
gio.Hostname = extractDeviceName(ctx, "hostname")
gio.NetbirdVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)

View File

@@ -15,7 +15,7 @@ func UpdateStaticInfoAsync() {
}
// GetInfo retrieves system information for WASM environment
func GetInfo(_ context.Context) *Info {
func GetInfo(ctx context.Context) *Info {
info := &Info{
GoOS: runtime.GOOS,
Kernel: runtime.GOARCH,
@@ -25,6 +25,7 @@ func GetInfo(_ context.Context) *Info {
Hostname: "wasm-client",
CPUs: runtime.NumCPU(),
NetbirdVersion: version.NetbirdVersion(),
DiskEncryption: detectDiskEncryption(ctx),
}
collectBrowserInfo(info)

View File

@@ -73,6 +73,7 @@ func GetInfo(ctx context.Context) *Info {
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
DiskEncryption: detectDiskEncryption(ctx),
}
return gio

View File

@@ -35,6 +35,7 @@ func GetInfo(ctx context.Context) *Info {
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
DiskEncryption: detectDiskEncryption(ctx),
}
addrs, err := networkAddresses()

View File

@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
postUpStatusOutput = overview.FullDetailSummary()
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
preDownStatusOutput = overview.FullDetailSummary()
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration)
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
statusOutput = overview.FullDetailSummary()
}
request := &proto.DebugBundleRequest{

View File

@@ -9,20 +9,29 @@ import (
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
netbird "github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/client/proto"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
const (
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
icmpEchoRequest = 8
icmpCodeEcho = 0
pingBufferSize = 1500
)
func main() {
@@ -113,18 +122,45 @@ func createStopMethod(client *netbird.Client) js.Func {
})
}
// validateSSHArgs validates SSH connection arguments
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
if len(args) < 2 {
return "", 0, "", js.ValueOf("error: requires host and port")
}
if args[0].Type() != js.TypeString {
return "", 0, "", js.ValueOf("host parameter must be a string")
}
if args[1].Type() != js.TypeNumber {
return "", 0, "", js.ValueOf("port parameter must be a number")
}
host = args[0].String()
port = args[1].Int()
username = "root"
if len(args) > 2 {
if args[2].Type() == js.TypeString && args[2].String() != "" {
username = args[2].String()
} else if args[2].Type() != js.TypeString {
return "", 0, "", js.ValueOf("username parameter must be a string")
}
}
return host, port, username, js.Undefined()
}
// createSSHMethod creates the SSH connection method
func createSSHMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: requires host and port")
host, port, username, validationErr := validateSSHArgs(args)
if !validationErr.IsUndefined() {
if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
return validationErr
}
host := args[0].String()
port := args[1].Int()
username := "root"
if len(args) > 2 && args[2].String() != "" {
username = args[2].String()
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(validationErr)
})
}
var jwtToken string
@@ -154,6 +190,110 @@ func createSSHMethod(client *netbird.Client) js.Func {
})
}
func performPing(client *netbird.Client, hostname string) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
start := time.Now()
conn, err := client.Dial(ctx, "ping", hostname)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
return
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close ping connection: %v", err)
}
}()
icmpData := make([]byte, 8)
icmpData[0] = icmpEchoRequest
icmpData[1] = icmpCodeEcho
if _, err := conn.Write(icmpData); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
return
}
buf := make([]byte, pingBufferSize)
if _, err := conn.Read(buf); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
return
}
latency := time.Since(start)
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
}
func performPingTCP(client *netbird.Client, hostname string, port int) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
address := fmt.Sprintf("%s:%d", hostname, port)
start := time.Now()
conn, err := client.Dial(ctx, "tcp", address)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
return
}
latency := time.Since(start)
if err := conn.Close(); err != nil {
log.Debugf("failed to close TCP connection: %v", err)
}
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
}
// createPingMethod creates the ping method
func createPingMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: hostname required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
hostname := args[0].String()
return createPromise(func(resolve, reject js.Value) {
performPing(client, hostname)
resolve.Invoke(js.Undefined())
})
})
}
// createPingTCPMethod creates the pingtcp method
func createPingTCPMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeNumber {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a number"))
})
}
hostname := args[0].String()
port := args[1].Int()
return createPromise(func(resolve, reject js.Value) {
performPingTCP(client, hostname, port)
resolve.Invoke(js.Undefined())
})
})
}
// createProxyRequestMethod creates the proxyRequest method
func createProxyRequestMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -162,6 +302,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
}
request := args[0]
if request.Type() != js.TypeObject {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("request parameter must be an object"))
})
}
return createPromise(func(resolve, reject js.Value) {
response, err := http.ProxyRequest(client, request)
@@ -181,11 +326,145 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
return js.ValueOf("error: hostname and port required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a string"))
})
}
proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String())
})
}
// getStatusOverview is a helper to get the status overview
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
fullStatus, err := client.Status()
if err != nil {
return nbstatus.OutputOverview{}, err
}
pbFullStatus := fullStatus.ToProto()
statusResp := &proto.StatusResponse{
DaemonVersion: version.NetbirdVersion(),
FullStatus: pbFullStatus,
}
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
}
// createStatusMethod creates the status method that returns JSON
func createStatusMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonStr, err := overview.JSON()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
resolve.Invoke(jsonObj)
})
})
}
// createStatusSummaryMethod creates the statusSummary method
func createStatusSummaryMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
summary := overview.GeneralSummary(false, false, false, false)
js.Global().Get("console").Call("log", summary)
resolve.Invoke(js.Undefined())
})
})
}
// createStatusDetailMethod creates the statusDetail method
func createStatusDetailMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
detail := overview.FullDetailSummary()
js.Global().Get("console").Call("log", detail)
resolve.Invoke(js.Undefined())
})
})
}
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
syncResp, err := client.GetLatestSyncResponse()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
AllowPartial: true,
}
jsonBytes, err := options.Marshal(syncResp)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
resolve.Invoke(jsonObj)
})
})
}
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
func createSetLogLevelMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: log level required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("log level parameter must be a string"))
})
}
logLevel := args[0].String()
return createPromise(func(resolve, reject js.Value) {
if err := client.SetLogLevel(logLevel); err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
return
}
log.Infof("Log level set to: %s", logLevel)
resolve.Invoke(js.ValueOf(true))
})
})
}
// createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
@@ -237,17 +516,24 @@ func createClientObject(client *netbird.Client) js.Value {
obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client)
obj["ping"] = createPingMethod(client)
obj["pingtcp"] = createPingTCPMethod(client)
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
obj["setLogLevel"] = createSetLogLevelMethod(client)
return js.ValueOf(obj)
}
// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(this js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]

8
go.mod
View File

@@ -78,8 +78,8 @@ require (
github.com/pion/logging v0.2.4
github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0
github.com/pion/stun/v3 v3.0.0
github.com/pion/transport/v3 v3.0.7
github.com/pion/stun/v3 v3.1.0
github.com/pion/transport/v3 v3.1.1
github.com/pion/turn/v3 v3.0.1
github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.23.2
@@ -241,7 +241,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.7 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pion/turn/v4 v4.1.1 // indirect
@@ -263,7 +263,7 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect

16
go.sum
View File

@@ -444,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
@@ -455,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
@@ -574,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

View File

@@ -798,15 +798,15 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
"redirectURI": redirectURI,
"scopes": []string{"openid", "profile", "email"},
"insecureEnableGroups": true,
//some providers don't return email verified, so we need to skip it if not present (e.g., Entra, Okta, Duo)
"insecureSkipEmailVerified": true,
}
switch cfg.Type {
case "zitadel":
oidcConfig["getUserInfo"] = true
case "entra":
oidcConfig["insecureSkipEmailVerified"] = true
oidcConfig["claimMapping"] = map[string]string{"email": "preferred_username"}
case "okta":
oidcConfig["insecureSkipEmailVerified"] = true
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}

File diff suppressed because it is too large Load Diff

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -197,6 +198,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range account.Peers {
if !c.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
@@ -223,9 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,7 +325,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -335,12 +342,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return err
}
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -434,7 +447,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
@@ -445,11 +465,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -472,7 +492,8 @@ func (c *Controller) getPeerNetworkMapExp(
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
@@ -483,7 +504,7 @@ func (c *Controller) getPeerNetworkMapExp(
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
@@ -798,7 +819,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
@@ -809,11 +838,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -3,6 +3,7 @@ package controller
import (
"context"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -14,6 +15,7 @@ type Repository interface {
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
}
type repository struct {
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}

View File

@@ -0,0 +1,13 @@
package zones
import (
"context"
)
type Manager interface {
GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
}

View File

@@ -0,0 +1,161 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager zones.Manager
}
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiZones := make([]*api.Zone, 0, len(allZones))
for _, zone := range allZones {
apiZones = append(apiZones, zone.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiZones)
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
zone.ID = zoneID
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,229 @@
package manager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
dnsDomain string
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
dnsDomain: dnsDomain,
}
}
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
return nil, err
}
zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing zone: %w", err)
}
}
if existingZone != nil {
return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain)
}
for _, groupID := range zone.DistributionGroups {
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
if err = transaction.CreateZone(ctx, zone); err != nil {
return fmt.Errorf("failed to create zone: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta())
return zone, nil
}
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
if err != nil {
return nil, fmt.Errorf("failed to get zone: %w", err)
}
if zone.Domain != updatedZone.Domain {
return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated")
}
zone.Name = updatedZone.Name
zone.Enabled = updatedZone.Enabled
zone.EnableSearchDomain = updatedZone.EnableSearchDomain
zone.DistributionGroups = updatedZone.DistributionGroups
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range zone.DistributionGroups {
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
if err = transaction.UpdateZone(ctx, zone); err != nil {
return fmt.Errorf("failed to update zone: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return zone, nil
}
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get records: %w", err)
}
err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to delete zone dns records: %w", err)
}
err = transaction.DeleteZone(ctx, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to delete zone: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
for _, record := range records {
eventsToStore = append(eventsToStore, func() {
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta)
})
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta())
})
return nil
})
if err != nil {
return err
}
for _, event := range eventsToStore {
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error {
if m.dnsDomain != "" && m.dnsDomain == domain {
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if settings.DNSDomain != "" && settings.DNSDomain == domain {
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
}
return nil
}

View File

@@ -0,0 +1,553 @@
package manager
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testZoneID = "test-zone-id"
testGroupID = "test-group-id"
testDNSDomain = "netbird.selfhosted"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
Name: "Test Group",
},
},
})
require.NoError(t, err)
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
dnsDomain: testDNSDomain,
}
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllZones(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone1)
require.NoError(t, err)
zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID})
err = testStore.CreateZone(ctx, zone2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, zone1.ID, result[0].ID)
assert.Equal(t, zone2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, zone.ID, result.ID)
assert.Equal(t, zone.Name, result.Name)
assert.Equal(t, zone.Domain, result.Domain)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneCreated, activityID)
}
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ID)
assert.Equal(t, testAccountID, result.AccountID)
assert.Equal(t, inputZone.Name, result.Name)
assert.Equal(t, inputZone.Domain, result.Domain)
assert.Equal(t, inputZone.Enabled, result.Enabled)
assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain)
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("invalid group", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{"invalid-group"},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("duplicate domain", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "duplicate.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.AlreadyExists, s.Type())
})
t.Run("peer DNS domain conflict", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
account, err := testStore.GetAccount(ctx, testAccountID)
require.NoError(t, err)
account.Settings.DNSDomain = "peers.example.com"
err = testStore.SaveAccount(ctx, account)
require.NoError(t, err)
inputZone := &zones.Zone{
Name: "Test Zone",
Domain: "peers.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("default DNS domain conflict", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "Test Zone",
Domain: testDNSDomain,
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain))
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
}
func TestManagerImpl_UpdateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
updatedZone := &zones.Zone{
ID: existingZone.ID,
Name: "Updated Name",
Domain: "example.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, existingZone.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneUpdated, activityID)
}
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, updatedZone.Name, result.Name)
assert.Equal(t, updatedZone.Enabled, result.Enabled)
assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
})
t.Run("domain change not allowed", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
updatedZone := &zones.Zone{
ID: existingZone.ID,
Name: "Test Zone",
Domain: "different.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone domain cannot be updated")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: testZoneID,
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: "non-existent-zone",
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_DeleteZone(t *testing.T) {
ctx := context.Background()
t.Run("success with records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCallCount++
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
}
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 3, storeEventCallCount)
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.Error(t, err)
zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.NoError(t, err)
assert.Empty(t, zoneRecords)
})
t.Run("success without records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, zone.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneDeleted, activityID)
}
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err)
})
}

View File

@@ -0,0 +1,13 @@
package records
import (
"context"
)
type Manager interface {
GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error)
GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error)
CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error
}

View File

@@ -0,0 +1,191 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager records.Manager
}
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiRecords := make([]*api.DNSRecord, 0, len(allRecords))
for _, record := range allRecords {
apiRecords = append(apiRecords, record.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiRecords)
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
record := new(records.Record)
record.FromAPIRequest(&req)
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
record := new(records.Record)
record.FromAPIRequest(&req)
record.ID = recordID
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,236 @@
package manager
import (
"context"
"fmt"
"strings"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
}
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
err = validateRecordConflicts(ctx, transaction, zone, record)
if err != nil {
return err
}
if err = transaction.CreateDNSRecord(ctx, record); err != nil {
return fmt.Errorf("failed to create dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return record, nil
}
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
var record *records.Record
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content
record.Name = updatedRecord.Name
record.Type = updatedRecord.Type
record.Content = updatedRecord.Content
record.TTL = updatedRecord.TTL
if hasChanges {
if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil {
return err
}
}
if err = transaction.UpdateDNSRecord(ctx, record); err != nil {
return fmt.Errorf("failed to update dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return record, nil
}
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var record *records.Record
var zone *zones.Zone
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID)
if err != nil {
return fmt.Errorf("failed to delete dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// validateRecordConflicts checks for duplicate records and CNAME conflicts
func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error {
if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) {
return status.Errorf(status.InvalidArgument, "record name does not belong to zone")
}
existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name)
if err != nil {
return fmt.Errorf("failed to check existing records: %w", err)
}
for _, existing := range existingRecords {
if existing.ID == record.ID {
continue
}
if existing.Type == record.Type && existing.Content == record.Content {
return status.Errorf(status.AlreadyExists, "identical record already exists")
}
if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME {
return status.Errorf(status.InvalidArgument,
"An A, AAAA, or CNAME record with name %s already exists", record.Name)
}
}
return nil
}

View File

@@ -0,0 +1,573 @@
package manager
import (
"context"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testRecordID = "test-record-id"
testGroupID = "test-group-id"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
Name: "Test Group",
},
},
})
require.NoError(t, err)
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err = testStore.CreateZone(ctx, zone)
require.NoError(t, err)
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
}
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllRecords(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, record1.ID, result[0].ID)
assert.Equal(t, record2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.Equal(t, record.ID, result.ID)
assert.Equal(t, record.Name, result.Name)
assert.Equal(t, record.Type, result.Type)
assert.Equal(t, record.Content, result.Content)
assert.Equal(t, record.TTL, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success - A record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ID)
assert.Equal(t, testAccountID, result.AccountID)
assert.Equal(t, zone.ID, result.ZoneID)
assert.Equal(t, inputRecord.Name, result.Name)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
assert.Equal(t, inputRecord.TTL, result.TTL)
})
t.Run("success - AAAA record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "ipv6.example.com",
Type: records.RecordTypeAAAA,
Content: "2001:db8::1",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("success - CNAME record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "www.example.com",
Type: records.RecordTypeCNAME,
Content: "example.com",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record name not in zone", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.different.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "does not belong to zone")
})
t.Run("duplicate record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "identical record already exists")
})
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeCNAME,
Content: "example.com",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "already exists")
})
}
func TestManagerImpl_UpdateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: existingRecord.ID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100", // Changed IP
TTL: 600, // Changed TTL
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, existingRecord.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordUpdated, activityID)
}
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, updatedRecord.Content, result.Content)
assert.Equal(t, updatedRecord.TTL, result.TTL)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
})
t.Run("update only TTL - no validation", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: existingRecord.ID,
Name: existingRecord.Name,
Type: existingRecord.Type,
Content: existingRecord.Content,
TTL: 600, // Only TTL changed
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored
}
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, 600, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: testRecordID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: "non-existent-record",
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("update creates duplicate", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: record2.ID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "identical record already exists")
})
}
func TestManagerImpl_DeleteRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, record.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordDeleted, activityID)
}
err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
_, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID)
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err)
})
}

View File

@@ -0,0 +1,129 @@
package records
import (
"errors"
"net"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type RecordType string
const (
RecordTypeA RecordType = "A"
RecordTypeAAAA RecordType = "AAAA"
RecordTypeCNAME RecordType = "CNAME"
)
type Record struct {
AccountID string `gorm:"index"`
ZoneID string `gorm:"index"`
ID string `gorm:"primaryKey"`
Name string
Type RecordType
Content string
TTL int
}
func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record {
return &Record{
ID: xid.New().String(),
AccountID: accountID,
ZoneID: zoneID,
Name: name,
Type: recordType,
Content: content,
TTL: ttl,
}
}
func (r *Record) ToAPIResponse() *api.DNSRecord {
recordType := api.DNSRecordType(r.Type)
return &api.DNSRecord{
Id: r.ID,
Name: r.Name,
Type: recordType,
Content: r.Content,
Ttl: r.TTL,
}
}
func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) {
r.Name = req.Name
r.Type = RecordType(req.Type)
r.Content = req.Content
r.TTL = req.Ttl
}
func (r *Record) Validate() error {
if r.Name == "" {
return errors.New("record name is required")
}
if !util.IsValidDomain(r.Name) {
return errors.New("invalid record name format")
}
if r.Type == "" {
return errors.New("record type is required")
}
switch r.Type {
case RecordTypeA:
if err := validateIPv4(r.Content); err != nil {
return err
}
case RecordTypeAAAA:
if err := validateIPv6(r.Content); err != nil {
return err
}
case RecordTypeCNAME:
if !util.IsValidDomain(r.Content) {
return errors.New("invalid CNAME record format")
}
default:
return errors.New("invalid record type, must be A, AAAA, or CNAME")
}
if r.TTL < 0 {
return errors.New("TTL cannot be negative")
}
return nil
}
func (r *Record) EventMeta(zoneID, zoneName string) map[string]any {
return map[string]any{
"name": r.Name,
"type": string(r.Type),
"content": r.Content,
"ttl": r.TTL,
"zone_id": zoneID,
"zone_name": zoneName,
}
}
func validateIPv4(content string) error {
if content == "" {
return errors.New("A record is required") //nolint:staticcheck
}
ip := net.ParseIP(content)
if ip == nil || ip.To4() == nil {
return errors.New("A record must be a valid IPv4 address") //nolint:staticcheck
}
return nil
}
func validateIPv6(content string) error {
if content == "" {
return errors.New("AAAA record is required")
}
ip := net.ParseIP(content)
if ip == nil || ip.To4() != nil {
return errors.New("AAAA record must be a valid IPv6 address")
}
return nil
}

View File

@@ -0,0 +1,89 @@
package zones
import (
"errors"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type Zone struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string
Enabled bool
EnableSearchDomain bool
DistributionGroups []string `gorm:"serializer:json"`
Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"`
}
func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone {
return &Zone{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Domain: domain,
Enabled: enabled,
EnableSearchDomain: enableSearchDomain,
DistributionGroups: distributionGroups,
}
}
func (z *Zone) ToAPIResponse() *api.Zone {
apiRecords := make([]api.DNSRecord, 0, len(z.Records))
for _, record := range z.Records {
if apiRecord := record.ToAPIResponse(); apiRecord != nil {
apiRecords = append(apiRecords, *apiRecord)
}
}
return &api.Zone{
DistributionGroups: z.DistributionGroups,
Domain: z.Domain,
EnableSearchDomain: z.EnableSearchDomain,
Enabled: z.Enabled,
Id: z.ID,
Name: z.Name,
Records: apiRecords,
}
}
func (z *Zone) FromAPIRequest(req *api.ZoneRequest) {
z.Name = req.Name
z.Domain = req.Domain
z.EnableSearchDomain = req.EnableSearchDomain
z.DistributionGroups = req.DistributionGroups
enabled := true
if req.Enabled != nil {
enabled = *req.Enabled
}
z.Enabled = enabled
}
func (z *Zone) Validate() error {
if z.Name == "" {
return errors.New("zone name is required")
}
if len(z.Name) > 255 {
return errors.New("zone name exceeds maximum length of 255 characters")
}
if !util.IsValidDomain(z.Domain) {
return errors.New("invalid zone domain format")
}
if len(z.DistributionGroups) == 0 {
return errors.New("at least one distribution group is required")
}
return nil
}
func (z *Zone) EventMeta() map[string]any {
return map[string]any{"name": z.Name, "domain": z.Domain}
}

View File

@@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}

View File

@@ -8,6 +8,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -158,3 +162,15 @@ func (s *BaseServer) NetworksManager() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}
func (s *BaseServer) ZonesManager() zones.Manager {
return Create(s, func() zones.Manager {
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
})
}
func (s *BaseServer) RecordsManager() records.Manager {
return Create(s, func() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
})
}

View File

@@ -376,6 +376,7 @@ func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
@@ -433,9 +434,16 @@ func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfi
if config.CLIAuthAudience != "" {
audience = config.CLIAuthAudience
}
audiences := []string{config.AuthAudience}
if config.CLIAuthAudience != "" && config.CLIAuthAudience != config.AuthAudience {
audiences = append(audiences, config.CLIAuthAudience)
}
return &proto.JWTConfig{
Issuer: issuer,
Audience: audience,
Audiences: audiences,
KeysLocation: keysLocation,
}
}

View File

@@ -6,9 +6,12 @@ import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
@@ -148,3 +151,51 @@ func generateTestData(size int) nbdns.Config {
return config
}
func TestBuildJWTConfig_Audiences(t *testing.T) {
tests := []struct {
name string
authAudience string
cliAuthAudience string
expectedAudiences []string
expectedAudience string
}{
{
name: "only_auth_audience",
authAudience: "dashboard-aud",
cliAuthAudience: "",
expectedAudiences: []string{"dashboard-aud"},
expectedAudience: "dashboard-aud",
},
{
name: "both_audiences_different",
authAudience: "dashboard-aud",
cliAuthAudience: "cli-aud",
expectedAudiences: []string{"dashboard-aud", "cli-aud"},
expectedAudience: "cli-aud",
},
{
name: "both_audiences_same",
authAudience: "same-aud",
cliAuthAudience: "same-aud",
expectedAudiences: []string{"same-aud"},
expectedAudience: "same-aud",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
config := &nbconfig.HttpServerConfig{
AuthIssuer: "https://issuer.example.com",
AuthAudience: tc.authAudience,
CLIAuthAudience: tc.cliAuthAudience,
}
result := buildJWTConfig(config, nil)
assert.NotNil(t, result)
assert.Equal(t, tc.expectedAudiences, result.Audiences, "audiences should match expected")
assert.Equal(t, tc.expectedAudience, result.Audience, "audience should match expected")
})
}
}

View File

@@ -470,6 +470,16 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
})
}
diskEncryptionVolumes := make([]nbpeer.DiskEncryptionVolume, 0)
if meta.GetDiskEncryption() != nil {
for _, vol := range meta.GetDiskEncryption().GetVolumes() {
diskEncryptionVolumes = append(diskEncryptionVolumes, nbpeer.DiskEncryptionVolume{
Path: vol.GetPath(),
Encrypted: vol.GetEncrypted(),
})
}
}
return nbpeer.PeerSystemMeta{
Hostname: meta.GetHostname(),
GoOS: meta.GetGoOS(),
@@ -501,6 +511,9 @@ func extractPeerMeta(ctx context.Context, meta *proto.PeerSystemMeta) nbpeer.Pee
LazyConnectionEnabled: meta.GetFlags().GetLazyConnectionEnabled(),
},
Files: files,
DiskEncryption: nbpeer.DiskEncryptionInfo{
Volumes: diskEncryptionVolumes,
},
}
}

View File

@@ -295,7 +295,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil {
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
@@ -388,7 +388,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return newSettings, nil
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -402,6 +402,18 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, new
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
if newSettings.DNSDomain != oldSettings.DNSDomain && newSettings.DNSDomain != "" {
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, newSettings.DNSDomain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing zone: %w", err)
}
}
if existingZone != nil {
return status.Errorf(status.InvalidArgument, "peer DNS domain %s conflicts with existing custom DNS zone", newSettings.DNSDomain)
}
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID)
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -397,7 +398,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
}
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@@ -1676,7 +1677,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
},
}
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}})
routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2"))
assert.Len(t, routes, 2)
routeIDs := make(map[route.ID]struct{}, 2)
@@ -1686,7 +1687,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) {
assert.Contains(t, routeIDs, route.ID("route-2"))
assert.Contains(t, routeIDs, route.ID("route-3"))
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}})
emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3"))
assert.Len(t, emptyRoutes, 0)
}
@@ -2095,6 +2096,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T)
}
}
func TestDefaultAccountManager_UpdateAccountSettings_DNSDomainConflict(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
ctx := context.Background()
err = manager.Store.CreateZone(ctx, &zones.Zone{
ID: "test-zone-id",
AccountID: accountID,
Name: "Test Zone",
Domain: "custom.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{},
})
require.NoError(t, err, "unable to create custom DNS zone")
_, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{
DNSDomain: "custom.example.com",
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.Error(t, err, "expecting to fail when DNS domain conflicts with custom zone")
assert.Contains(t, err.Error(), "conflicts with existing custom DNS zone")
}
func TestAccount_GetExpiredPeers(t *testing.T) {
type test struct {
name string

View File

@@ -187,6 +187,14 @@ const (
IdentityProviderUpdated Activity = 94
IdentityProviderDeleted Activity = 95
DNSZoneCreated Activity = 96
DNSZoneUpdated Activity = 97
DNSZoneDeleted Activity = 98
DNSRecordCreated Activity = 99
DNSRecordUpdated Activity = 100
DNSRecordDeleted Activity = 101
AccountDeleted Activity = 99999
)
@@ -303,6 +311,14 @@ var activityMap = map[Activity]Code{
IdentityProviderCreated: {"Identity provider created", "identityprovider.create"},
IdentityProviderUpdated: {"Identity provider updated", "identityprovider.update"},
IdentityProviderDeleted: {"Identity provider deleted", "identityprovider.delete"},
DNSZoneCreated: {"DNS zone created", "dns.zone.create"},
DNSZoneUpdated: {"DNS zone updated", "dns.zone.update"},
DNSZoneDeleted: {"DNS zone deleted", "dns.zone.delete"},
DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"},
DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"},
DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"},
}
// StringCode returns a string code of the activity

View File

@@ -26,6 +26,8 @@ type UserDataCache interface {
Get(ctx context.Context, key string) (*idp.UserData, error)
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
Delete(ctx context.Context, key string) error
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
}
// UserDataCacheImpl is a struct that implements the UserDataCache interface.
@@ -51,6 +53,29 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
return u.cache.Delete(ctx, key)
}
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
var users []*idp.UserData
v, err := u.cache.Get(ctx, key, &users)
if err != nil {
return nil, err
}
switch v := v.(type) {
case []*idp.UserData:
return v, nil
case *[]*idp.UserData:
return *v, nil
case []byte:
return unmarshalUserData(v)
}
return nil, fmt.Errorf("unexpected type: %T", v)
}
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
}
// NewUserDataCache creates a new UserDataCacheImpl object.
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
simpleCache := cache.New[any](store)

View File

@@ -15,7 +15,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/settings"
@@ -56,7 +59,7 @@ const (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -138,6 +141,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
zonesManager.RegisterEndpoints(router, zManager)
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)

View File

@@ -10,6 +10,7 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
@@ -298,8 +299,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.networkMapController.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(r.Context(), dnsDomain)
netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}

View File

@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token)
}
if m.rateLimiter != nil {
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}
func isTerraformRequest(r *http.Request) bool {
ua := strings.ToLower(r.Header.Get("User-Agent"))
return strings.Contains(ua, "terraform")
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {

View File

@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
})
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Test various Terraform user agent formats
terraformUserAgents := []string{
"Terraform/1.5.0",
"terraform/1.0.0",
"Terraform-Provider/2.0.0",
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
}
for _, userAgent := range terraformUserAgents {
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", userAgent)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
})
}
})
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {

View File

@@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
@@ -93,8 +95,10 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
peersManager := peers.NewManager(store, permissionsManager)
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -136,9 +136,10 @@ func NewAuth0Manager(config Auth0ClientConfig, appMetrics telemetry.AppMetrics)
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}
helper := JsonParser{}
if config.AuthIssuer == "" {

View File

@@ -48,13 +48,12 @@ type AuthentikCredentials struct {
}
// NewAuthentikManager creates a new instance of the AuthentikManager.
func NewAuthentikManager(config AuthentikClientConfig,
appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
func NewAuthentikManager(config AuthentikClientConfig, appMetrics telemetry.AppMetrics) (*AuthentikManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}

View File

@@ -58,9 +58,10 @@ func NewAzureManager(config AzureClientConfig, appMetrics telemetry.AppMetrics)
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}
helper := JsonParser{}
if config.ClientID == "" {

View File

@@ -5,7 +5,6 @@ import (
"encoding/base64"
"fmt"
"net/http"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2/google"
@@ -49,9 +48,10 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}
helper := JsonParser{}
if config.CustomerID == "" {

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"net/http"
"strings"
"time"
v1 "github.com/TheJumpCloud/jcapi-go/v1"
@@ -46,9 +45,10 @@ func NewJumpCloudManager(config JumpCloudClientConfig, appMetrics telemetry.AppM
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}
helper := JsonParser{}
if config.APIToken == "" {

View File

@@ -63,9 +63,10 @@ func NewKeycloakManager(config KeycloakClientConfig, appMetrics telemetry.AppMet
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}
helper := JsonParser{}
if config.ClientID == "" {

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"net/url"
"strings"
"time"
"github.com/okta/okta-sdk-golang/v2/okta"
"github.com/okta/okta-sdk-golang/v2/okta/query"
@@ -45,7 +44,7 @@ func NewOktaManager(config OktaClientConfig, appMetrics telemetry.AppMetrics) (*
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}

View File

@@ -8,7 +8,6 @@ import (
"net/url"
"slices"
"strings"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -88,9 +87,10 @@ func NewPocketIdManager(config PocketIdClientConfig, appMetrics telemetry.AppMet
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}
helper := JsonParser{}
if config.ManagementEndpoint == "" {

View File

@@ -4,7 +4,9 @@ import (
"encoding/json"
"math/rand"
"net/url"
"os"
"strings"
"time"
)
var (
@@ -69,3 +71,24 @@ func baseURL(rawURL string) string {
return parsedURL.Scheme + "://" + parsedURL.Host
}
const (
// Provides the env variable name for use with idpTimeout function
idpTimeoutEnv = "NB_IDP_TIMEOUT"
// Sets the defaultTimeout to 10s.
defaultTimeout = 10 * time.Second
)
// idpTimeout returns a timeout value for the IDP
func idpTimeout() time.Duration {
timeoutStr, ok := os.LookupEnv(idpTimeoutEnv)
if !ok || timeoutStr == "" {
return defaultTimeout
}
timeout, err := time.ParseDuration(timeoutStr)
if err != nil {
return defaultTimeout
}
return timeout
}

View File

@@ -164,9 +164,10 @@ func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetri
httpTransport.MaxIdleConns = 5
httpClient := &http.Client{
Timeout: 10 * time.Second,
Timeout: idpTimeout(),
Transport: httpTransport,
}
helper := JsonParser{}
hasPAT := config.PAT != ""

View File

@@ -95,6 +95,27 @@ type File struct {
ProcessIsRunning bool
}
// DiskEncryptionVolume represents encryption status of a volume.
type DiskEncryptionVolume struct {
Path string
Encrypted bool
}
// DiskEncryptionInfo holds encryption info for all volumes.
type DiskEncryptionInfo struct {
Volumes []DiskEncryptionVolume `gorm:"serializer:json"`
}
// IsEncrypted returns true if the volume at path is encrypted.
func (d DiskEncryptionInfo) IsEncrypted(path string) bool {
for _, v := range d.Volumes {
if v.Path == path {
return v.Encrypted
}
}
return false
}
// Flags defines a set of options to control feature behavior
type Flags struct {
RosenpassEnabled bool
@@ -130,6 +151,7 @@ type PeerSystemMeta struct { //nolint:revive
Environment Environment `gorm:"serializer:json"`
Flags Flags `gorm:"serializer:json"`
Files []File `gorm:"serializer:json"`
DiskEncryption DiskEncryptionInfo `gorm:"serializer:json"`
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
@@ -159,6 +181,19 @@ func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
return false
}
sort.Slice(p.DiskEncryption.Volumes, func(i, j int) bool {
return p.DiskEncryption.Volumes[i].Path < p.DiskEncryption.Volumes[j].Path
})
sort.Slice(other.DiskEncryption.Volumes, func(i, j int) bool {
return other.DiskEncryption.Volumes[i].Path < other.DiskEncryption.Volumes[j].Path
})
equalDiskEncryption := slices.EqualFunc(p.DiskEncryption.Volumes, other.DiskEncryption.Volumes, func(vol DiskEncryptionVolume, oVol DiskEncryptionVolume) bool {
return vol.Path == oVol.Path && vol.Encrypted == oVol.Encrypted
})
if !equalDiskEncryption {
return false
}
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&

View File

@@ -18,6 +18,7 @@ const (
GeoLocationCheckName = "GeoLocationCheck"
PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
ProcessCheckName = "ProcessCheck"
DiskEncryptionCheckName = "DiskEncryptionCheck"
CheckActionAllow string = "allow"
CheckActionDeny string = "deny"
@@ -58,6 +59,7 @@ type ChecksDefinition struct {
GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
ProcessCheck *ProcessCheck `json:",omitempty"`
DiskEncryptionCheck *DiskEncryptionCheck `json:",omitempty"`
}
// Copy returns a copy of a checks definition.
@@ -110,6 +112,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
}
copy(cdCopy.ProcessCheck.Processes, processCheck.Processes)
}
if cd.DiskEncryptionCheck != nil {
cdCopy.DiskEncryptionCheck = &DiskEncryptionCheck{
LinuxPath: cd.DiskEncryptionCheck.LinuxPath,
DarwinPath: cd.DiskEncryptionCheck.DarwinPath,
WindowsPath: cd.DiskEncryptionCheck.WindowsPath,
}
}
return cdCopy
}
@@ -153,6 +162,9 @@ func (pc *Checks) GetChecks() []Check {
if pc.Checks.ProcessCheck != nil {
checks = append(checks, pc.Checks.ProcessCheck)
}
if pc.Checks.DiskEncryptionCheck != nil {
checks = append(checks, pc.Checks.DiskEncryptionCheck)
}
return checks
}
@@ -208,6 +220,10 @@ func buildPostureCheck(postureChecksID string, name string, description string,
postureChecks.Checks.ProcessCheck = toProcessCheck(processCheck)
}
if diskEncryptionCheck := checks.DiskEncryptionCheck; diskEncryptionCheck != nil {
postureChecks.Checks.DiskEncryptionCheck = toDiskEncryptionCheck(diskEncryptionCheck)
}
return &postureChecks, nil
}
@@ -242,6 +258,10 @@ func (pc *Checks) ToAPIResponse() *api.PostureCheck {
checks.ProcessCheck = toProcessCheckResponse(pc.Checks.ProcessCheck)
}
if pc.Checks.DiskEncryptionCheck != nil {
checks.DiskEncryptionCheck = toDiskEncryptionCheckResponse(pc.Checks.DiskEncryptionCheck)
}
return &api.PostureCheck{
Id: pc.ID,
Name: pc.Name,
@@ -386,3 +406,25 @@ func toProcessCheck(check *api.ProcessCheck) *ProcessCheck {
Processes: processes,
}
}
func toDiskEncryptionCheck(check *api.DiskEncryptionCheck) *DiskEncryptionCheck {
d := &DiskEncryptionCheck{}
if check.LinuxPath != nil {
d.LinuxPath = *check.LinuxPath
}
if check.DarwinPath != nil {
d.DarwinPath = *check.DarwinPath
}
if check.WindowsPath != nil {
d.WindowsPath = *check.WindowsPath
}
return d
}
func toDiskEncryptionCheckResponse(check *DiskEncryptionCheck) *api.DiskEncryptionCheck {
return &api.DiskEncryptionCheck{
LinuxPath: &check.LinuxPath,
DarwinPath: &check.DarwinPath,
WindowsPath: &check.WindowsPath,
}
}

View File

@@ -0,0 +1,52 @@
package posture
import (
"context"
"fmt"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
// DiskEncryptionCheck verifies that specified volumes are encrypted.
type DiskEncryptionCheck struct {
LinuxPath string
DarwinPath string
WindowsPath string
}
var _ Check = (*DiskEncryptionCheck)(nil)
// Name returns the name of the check.
func (d *DiskEncryptionCheck) Name() string {
return DiskEncryptionCheckName
}
// Check performs the disk encryption verification for the given peer.
func (d *DiskEncryptionCheck) Check(_ context.Context, peer nbpeer.Peer) (bool, error) {
var pathToCheck string
switch peer.Meta.GoOS {
case "linux":
pathToCheck = d.LinuxPath
case "darwin":
pathToCheck = d.DarwinPath
case "windows":
pathToCheck = d.WindowsPath
default:
return false, nil
}
if pathToCheck == "" {
return true, nil
}
return peer.Meta.DiskEncryption.IsEncrypted(pathToCheck), nil
}
// Validate checks the configuration of the disk encryption check.
func (d *DiskEncryptionCheck) Validate() error {
if d.LinuxPath == "" && d.DarwinPath == "" && d.WindowsPath == "" {
return fmt.Errorf("%s at least one path must be configured", d.Name())
}
return nil
}

View File

@@ -0,0 +1,306 @@
package posture
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/peer"
)
func TestDiskEncryptionCheck_Check(t *testing.T) {
tests := []struct {
name string
input peer.Peer
check DiskEncryptionCheck
wantErr bool
isValid bool
}{
{
name: "linux with encrypted root",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "/", Encrypted: true},
{Path: "/home", Encrypted: true},
},
},
},
},
check: DiskEncryptionCheck{
LinuxPath: "/",
},
wantErr: false,
isValid: true,
},
{
name: "linux with unencrypted root",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "/", Encrypted: false},
{Path: "/home", Encrypted: true},
},
},
},
},
check: DiskEncryptionCheck{
LinuxPath: "/",
},
wantErr: false,
isValid: false,
},
{
name: "linux with no volume info",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
DiskEncryption: peer.DiskEncryptionInfo{},
},
},
check: DiskEncryptionCheck{
LinuxPath: "/",
},
wantErr: false,
isValid: false,
},
{
name: "darwin with encrypted root",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "/", Encrypted: true},
},
},
},
},
check: DiskEncryptionCheck{
DarwinPath: "/",
},
wantErr: false,
isValid: true,
},
{
name: "darwin with unencrypted root",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "/", Encrypted: false},
},
},
},
},
check: DiskEncryptionCheck{
DarwinPath: "/",
},
wantErr: false,
isValid: false,
},
{
name: "windows with encrypted C drive",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "C:", Encrypted: true},
{Path: "D:", Encrypted: false},
},
},
},
},
check: DiskEncryptionCheck{
WindowsPath: "C:",
},
wantErr: false,
isValid: true,
},
{
name: "windows with unencrypted C drive",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "C:", Encrypted: false},
{Path: "D:", Encrypted: true},
},
},
},
},
check: DiskEncryptionCheck{
WindowsPath: "C:",
},
wantErr: false,
isValid: false,
},
{
name: "unsupported ios operating system",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "ios",
},
},
check: DiskEncryptionCheck{
LinuxPath: "/",
DarwinPath: "/",
WindowsPath: "C:",
},
wantErr: false,
isValid: false,
},
{
name: "unsupported android operating system",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "android",
},
},
check: DiskEncryptionCheck{
LinuxPath: "/",
DarwinPath: "/",
WindowsPath: "C:",
},
wantErr: false,
isValid: false,
},
{
name: "linux peer with no linux path configured passes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "linux",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "/", Encrypted: false},
},
},
},
},
check: DiskEncryptionCheck{
DarwinPath: "/",
WindowsPath: "C:",
},
wantErr: false,
isValid: true,
},
{
name: "darwin peer with no darwin path configured passes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "darwin",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "/", Encrypted: false},
},
},
},
},
check: DiskEncryptionCheck{
LinuxPath: "/",
WindowsPath: "C:",
},
wantErr: false,
isValid: true,
},
{
name: "windows peer with no windows path configured passes",
input: peer.Peer{
Meta: peer.PeerSystemMeta{
GoOS: "windows",
DiskEncryption: peer.DiskEncryptionInfo{
Volumes: []peer.DiskEncryptionVolume{
{Path: "C:", Encrypted: false},
},
},
},
},
check: DiskEncryptionCheck{
LinuxPath: "/",
DarwinPath: "/",
},
wantErr: false,
isValid: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid, err := tt.check.Check(context.Background(), tt.input)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.isValid, isValid)
})
}
}
func TestDiskEncryptionCheck_Validate(t *testing.T) {
testCases := []struct {
name string
check DiskEncryptionCheck
expectedError bool
}{
{
name: "valid linux, darwin and windows paths",
check: DiskEncryptionCheck{
LinuxPath: "/",
DarwinPath: "/",
WindowsPath: "C:",
},
expectedError: false,
},
{
name: "valid linux path only",
check: DiskEncryptionCheck{
LinuxPath: "/",
},
expectedError: false,
},
{
name: "valid darwin path only",
check: DiskEncryptionCheck{
DarwinPath: "/",
},
expectedError: false,
},
{
name: "valid windows path only",
check: DiskEncryptionCheck{
WindowsPath: "C:",
},
expectedError: false,
},
{
name: "invalid empty paths",
check: DiskEncryptionCheck{},
expectedError: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.check.Validate()
if tc.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestDiskEncryptionCheck_Name(t *testing.T) {
check := DiskEncryptionCheck{}
assert.Equal(t, DiskEncryptionCheckName, check.Name())
}

View File

@@ -27,6 +27,8 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -123,6 +125,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
&types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
&networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{},
&zones.Zone{}, &records.Record{},
)
if err != nil {
return nil, fmt.Errorf("auto migratePreAuto: %w", err)
@@ -4179,3 +4182,184 @@ func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingS
return userID, nil
}
func (s *SqlStore) CreateZone(ctx context.Context, zone *zones.Zone) error {
result := s.db.Create(zone)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create zone to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create zone to store")
}
return nil
}
func (s *SqlStore) UpdateZone(ctx context.Context, zone *zones.Zone) error {
result := s.db.Select("*").Save(zone)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update zone to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to update zone to store")
}
return nil
}
func (s *SqlStore) DeleteZone(ctx context.Context, accountID, zoneID string) error {
result := s.db.Delete(&zones.Zone{}, accountAndIDQueryCondition, accountID, zoneID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete zone from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete zone from store")
}
if result.RowsAffected == 0 {
return status.NewZoneNotFoundError(zoneID)
}
return nil
}
func (s *SqlStore) GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var zone *zones.Zone
result := tx.Preload("Records").Take(&zone, accountAndIDQueryCondition, accountID, zoneID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewZoneNotFoundError(zoneID)
}
log.WithContext(ctx).Errorf("failed to get zone from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zone from store")
}
return zone, nil
}
func (s *SqlStore) GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error) {
var zone *zones.Zone
result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&zone)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewZoneNotFoundError(domain)
}
log.WithContext(ctx).Errorf("failed to get zone by domain from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zone by domain from store")
}
return zone, nil
}
func (s *SqlStore) GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var zones []*zones.Zone
result := tx.Preload("Records").Find(&zones, accountIDCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get zones from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zones from store")
}
return zones, nil
}
func (s *SqlStore) CreateDNSRecord(ctx context.Context, record *records.Record) error {
result := s.db.Create(record)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create dns record to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to create dns record to store")
}
return nil
}
func (s *SqlStore) UpdateDNSRecord(ctx context.Context, record *records.Record) error {
result := s.db.Select("*").Save(record)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update dns record to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to update dns record to store")
}
return nil
}
func (s *SqlStore) DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error {
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete dns record from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete dns record from store")
}
if result.RowsAffected == 0 {
return status.NewDNSRecordNotFoundError(recordID)
}
return nil
}
func (s *SqlStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record *records.Record
result := tx.Where("account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID).Take(&record)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewDNSRecordNotFoundError(recordID)
}
log.WithContext(ctx).Errorf("failed to get dns record from store: %v", result.Error)
return nil, status.Errorf(status.Internal, "failed to get dns record from store")
}
return record, nil
}
func (s *SqlStore) GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var recordsList []*records.Record
result := tx.Where("account_id = ? AND zone_id = ?", accountID, zoneID).Find(&recordsList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get zone dns records from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zone dns records from store")
}
return recordsList, nil
}
func (s *SqlStore) GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var recordsList []*records.Record
result := tx.Where("account_id = ? AND zone_id = ? AND name = ?", accountID, zoneID, name).Find(&recordsList)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get zone dns records by name from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get zone dns records by name from store")
}
return recordsList, nil
}
func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error {
result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ?", accountID, zoneID)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete zone dns records from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete zone dns records from store")
}
return nil
}

View File

@@ -22,6 +22,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -4025,3 +4027,476 @@ func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
}
func TestSqlStore_CreateZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, zone.ID, savedZone.ID)
assert.Equal(t, zone.Name, savedZone.Name)
assert.Equal(t, zone.Domain, savedZone.Domain)
assert.Equal(t, zone.Enabled, savedZone.Enabled)
assert.Equal(t, zone.EnableSearchDomain, savedZone.EnableSearchDomain)
assert.Equal(t, zone.DistributionGroups, savedZone.DistributionGroups)
}
func TestSqlStore_GetZoneByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
tests := []struct {
name string
accountID string
zoneID string
expectError bool
}{
{
name: "retrieve existing zone",
accountID: accountID,
zoneID: zone.ID,
expectError: false,
},
{
name: "retrieve non-existing zone",
accountID: accountID,
zoneID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty zone ID",
accountID: accountID,
zoneID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, savedZone)
} else {
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, tt.zoneID, savedZone.ID)
}
})
}
}
func TestSqlStore_GetAccountZones(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone1 := zones.NewZone(accountID, "Zone 1", "example1.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone1)
require.NoError(t, err)
zone2 := zones.NewZone(accountID, "Zone 2", "example2.com", true, true, []string{"group1", "group2"})
err = store.CreateZone(context.Background(), zone2)
require.NoError(t, err)
allZones, err := store.GetAccountZones(context.Background(), LockingStrengthNone, accountID)
require.NoError(t, err)
require.NotNil(t, allZones)
assert.GreaterOrEqual(t, len(allZones), 2)
zoneIDs := make(map[string]bool)
for _, z := range allZones {
zoneIDs[z.ID] = true
}
assert.True(t, zoneIDs[zone1.ID])
assert.True(t, zoneIDs[zone2.ID])
}
func TestSqlStore_GetZoneByDomain(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
otherAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3c"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
tests := []struct {
name string
accountID string
domain string
expectError bool
errorType status.Type
}{
{
name: "retrieve existing zone by domain",
accountID: accountID,
domain: "example.com",
expectError: false,
},
{
name: "retrieve non-existing zone domain",
accountID: accountID,
domain: "non-existing.com",
expectError: true,
errorType: status.NotFound,
},
{
name: "retrieve with empty domain",
accountID: accountID,
domain: "",
expectError: true,
errorType: status.NotFound,
},
{
name: "retrieve with different account ID",
accountID: otherAccountID,
domain: "example.com",
expectError: true,
errorType: status.NotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedZone, err := store.GetZoneByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, tt.errorType, sErr.Type())
require.Nil(t, savedZone)
} else {
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, tt.domain, savedZone.Domain)
assert.Equal(t, zone.ID, savedZone.ID)
assert.Equal(t, zone.Name, savedZone.Name)
}
})
}
}
func TestSqlStore_UpdateZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
zone.Name = "Updated Zone"
zone.Domain = "updated.com"
zone.Enabled = false
zone.EnableSearchDomain = true
zone.DistributionGroups = []string{"group2", "group3"}
err = store.UpdateZone(context.Background(), zone)
require.NoError(t, err)
updatedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, updatedZone)
assert.Equal(t, "Updated Zone", updatedZone.Name)
assert.Equal(t, "updated.com", updatedZone.Domain)
assert.False(t, updatedZone.Enabled)
assert.True(t, updatedZone.EnableSearchDomain)
assert.Equal(t, []string{"group2", "group3"}, updatedZone.DistributionGroups)
}
func TestSqlStore_DeleteZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
err = store.DeleteZone(context.Background(), accountID, zone.ID)
require.NoError(t, err)
deletedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.Error(t, err)
require.Nil(t, deletedZone)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
}
func TestSqlStore_CreateDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.NoError(t, err)
require.NotNil(t, savedRecord)
assert.Equal(t, record.ID, savedRecord.ID)
assert.Equal(t, record.Name, savedRecord.Name)
assert.Equal(t, record.Type, savedRecord.Type)
assert.Equal(t, record.Content, savedRecord.Content)
assert.Equal(t, record.TTL, savedRecord.TTL)
assert.Equal(t, zone.ID, savedRecord.ZoneID)
}
func TestSqlStore_GetDNSRecordByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
tests := []struct {
name string
accountID string
zoneID string
recordID string
expectError bool
}{
{
name: "retrieve existing record",
accountID: accountID,
zoneID: zone.ID,
recordID: record.ID,
expectError: false,
},
{
name: "retrieve non-existing record",
accountID: accountID,
zoneID: zone.ID,
recordID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty record ID",
accountID: accountID,
zoneID: zone.ID,
recordID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID, tt.recordID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, savedRecord)
} else {
require.NoError(t, err)
require.NotNil(t, savedRecord)
assert.Equal(t, tt.recordID, savedRecord.ID)
}
})
}
}
func TestSqlStore_GetZoneDNSRecords(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
recordA := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), recordA)
require.NoError(t, err)
recordAAAA := records.NewRecord(accountID, zone.ID, "ipv6.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
err = store.CreateDNSRecord(context.Background(), recordAAAA)
require.NoError(t, err)
recordCNAME := records.NewRecord(accountID, zone.ID, "alias.example.com", records.RecordTypeCNAME, "www.example.com", 300)
err = store.CreateDNSRecord(context.Background(), recordCNAME)
require.NoError(t, err)
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, allRecords)
assert.Equal(t, 3, len(allRecords))
recordIDs := make(map[string]bool)
for _, r := range allRecords {
recordIDs[r.ID] = true
}
assert.True(t, recordIDs[recordA.ID])
assert.True(t, recordIDs[recordAAAA.ID])
assert.True(t, recordIDs[recordCNAME.ID])
}
func TestSqlStore_GetZoneDNSRecordsByName(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record1)
require.NoError(t, err)
record2 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
err = store.CreateDNSRecord(context.Background(), record2)
require.NoError(t, err)
record3 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
err = store.CreateDNSRecord(context.Background(), record3)
require.NoError(t, err)
recordsByName, err := store.GetZoneDNSRecordsByName(context.Background(), LockingStrengthNone, accountID, zone.ID, "www.example.com")
require.NoError(t, err)
require.NotNil(t, recordsByName)
assert.Equal(t, 2, len(recordsByName))
for _, r := range recordsByName {
assert.Equal(t, "www.example.com", r.Name)
}
}
func TestSqlStore_UpdateDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
record.Name = "api.example.com"
record.Content = "192.168.1.100"
record.TTL = 600
err = store.UpdateDNSRecord(context.Background(), record)
require.NoError(t, err)
updatedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.NoError(t, err)
require.NotNil(t, updatedRecord)
assert.Equal(t, "api.example.com", updatedRecord.Name)
assert.Equal(t, "192.168.1.100", updatedRecord.Content)
assert.Equal(t, 600, updatedRecord.TTL)
}
func TestSqlStore_DeleteDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
err = store.DeleteDNSRecord(context.Background(), accountID, zone.ID, record.ID)
require.NoError(t, err)
deletedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.Error(t, err)
require.Nil(t, deletedRecord)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
}
func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record1)
require.NoError(t, err)
record2 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
err = store.CreateDNSRecord(context.Background(), record2)
require.NoError(t, err)
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 2, len(allRecords))
err = store.DeleteZoneDNSRecords(context.Background(), accountID, zone.ID)
require.NoError(t, err)
remainingRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 0, len(remainingRecords))
}

View File

@@ -23,6 +23,8 @@ import (
"gorm.io/gorm"
"github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types"
@@ -209,6 +211,21 @@ type Store interface {
// SetFieldEncrypt sets the field encryptor for encrypting sensitive user data.
SetFieldEncrypt(enc *crypt.FieldEncrypt)
GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error)
CreateZone(ctx context.Context, zone *zones.Zone) error
UpdateZone(ctx context.Context, zone *zones.Zone) error
DeleteZone(ctx context.Context, accountID, zoneID string) error
GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error)
GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error)
GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error)
CreateDNSRecord(ctx context.Context, record *records.Record) error
UpdateDNSRecord(ctx context.Context, record *records.Record) error
DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error
GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error)
GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error)
GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error)
DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error
}
const (

View File

@@ -18,6 +18,8 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -150,17 +152,16 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool {
// GetRoutesToSync returns the enabled routes for the peer ID and the routes
// from the ACL peers that have distribution groups associated with the peer ID.
// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID.
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route {
func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID)
peerRoutesMembership := make(LookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
groupListMap := a.GetPeerGroups(peerID)
for _, peer := range aclPeers {
activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap)
groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups)
filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
@@ -274,6 +275,7 @@ func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
@@ -294,6 +296,8 @@ func (a *Account) GetPeerNetworkMap(
}
}
peerGroups := a.GetPeerGroups(peerID)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
@@ -307,7 +311,7 @@ func (a *Account) GetPeerNetworkMap(
peersToConnect = append(peersToConnect, p)
}
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect)
routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups)
routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers)
var networkResourcesFirewallRules []*RouteFirewallRule
@@ -323,6 +327,7 @@ func (a *Account) GetPeerNetworkMap(
if dnsManagementStatus {
var zones []nbdns.CustomZone
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers)
zones = append(zones, nbdns.CustomZone{
@@ -330,6 +335,10 @@ func (a *Account) GetPeerNetworkMap(
Records: records,
})
}
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
zones = append(zones, filteredAccountZones...)
dnsUpdate.CustomZones = zones
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
@@ -1881,3 +1890,66 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p
return filteredRecords
}
// filterPeerAppliedZones filters account zones based on the peer's group membership
func filterPeerAppliedZones(ctx context.Context, accountZones []*zones.Zone, peerGroups LookupMap) []nbdns.CustomZone {
var customZones []nbdns.CustomZone
if len(peerGroups) == 0 {
return customZones
}
for _, zone := range accountZones {
if !zone.Enabled || len(zone.Records) == 0 {
continue
}
hasAccess := false
for _, distGroupID := range zone.DistributionGroups {
if _, found := peerGroups[distGroupID]; found {
hasAccess = true
break
}
}
if !hasAccess {
continue
}
simpleRecords := make([]nbdns.SimpleRecord, 0, len(zone.Records))
for _, record := range zone.Records {
var recordType int
rData := record.Content
switch record.Type {
case records.RecordTypeA:
recordType = int(dns.TypeA)
case records.RecordTypeAAAA:
recordType = int(dns.TypeAAAA)
case records.RecordTypeCNAME:
recordType = int(dns.TypeCNAME)
rData = dns.Fqdn(record.Content)
default:
log.WithContext(ctx).Warnf("unknown DNS record type %s for record %s", record.Type, record.ID)
continue
}
simpleRecords = append(simpleRecords, nbdns.SimpleRecord{
Name: dns.Fqdn(record.Name),
Type: recordType,
Class: nbdns.DefaultClass,
TTL: record.TTL,
RData: rData,
})
}
customZones = append(customZones, nbdns.CustomZone{
Domain: dns.Fqdn(zone.Domain),
Records: simpleRecords,
SearchDomainDisabled: !zone.EnableSearchDomain,
NonAuthoritative: true,
})
}
return customZones
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -1425,3 +1427,515 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) {
})
}
}
func Test_filterPeerAppliedZones(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
accountZones []*zones.Zone
peerGroups LookupMap
expected []nbdns.CustomZone
}{
{
name: "empty peer groups returns empty custom zones",
accountZones: []*zones.Zone{},
peerGroups: LookupMap{},
expected: []nbdns.CustomZone{},
},
{
name: "peer has access to zone with A record",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "example.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "peer has access to zone with search domain enabled",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "internal.local",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "api.internal.local",
Type: records.RecordTypeA,
Content: "10.0.0.1",
TTL: 600,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "internal.local.",
Records: []nbdns.SimpleRecord{
{
Name: "api.internal.local.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 600,
RData: "10.0.0.1",
},
},
SearchDomainDisabled: false,
},
},
},
{
name: "peer has no access to zone",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "private.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group2"},
Records: []*records.Record{
{
ID: "record1",
Name: "secret.private.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{},
},
{
name: "disabled zone is filtered out",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "disabled.com",
Enabled: false,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.disabled.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{},
},
{
name: "zone with no records is filtered out",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "empty.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{},
},
{
name: "peer has access via multiple groups",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "multi.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1", "group2", "group3"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.multi.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group2": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "multi.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.multi.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "multiple zones with mixed access",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "allowed.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.allowed.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
{
ID: "zone2",
Domain: "denied.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group2"},
Records: []*records.Record{
{
ID: "record2",
Name: "www.denied.com",
Type: records.RecordTypeA,
Content: "192.168.1.2",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "allowed.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.allowed.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "zone with multiple record types",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "mixed.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.mixed.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
{
ID: "record2",
Name: "ipv6.mixed.com",
Type: records.RecordTypeAAAA,
Content: "2001:db8::1",
TTL: 600,
},
{
ID: "record3",
Name: "alias.mixed.com",
Type: records.RecordTypeCNAME,
Content: "www.mixed.com",
TTL: 900,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "mixed.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.mixed.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
{
Name: "ipv6.mixed.com.",
Type: int(dns.TypeAAAA),
Class: nbdns.DefaultClass,
TTL: 600,
RData: "2001:db8::1",
},
{
Name: "alias.mixed.com.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 900,
RData: "www.mixed.com.",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "multiple zones both accessible",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "first.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.first.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
{
ID: "zone2",
Domain: "second.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record2",
Name: "www.second.com",
Type: records.RecordTypeA,
Content: "192.168.1.2",
TTL: 600,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "first.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.first.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: false,
},
{
Domain: "second.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.second.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 600,
RData: "192.168.1.2",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "zone with multiple records of same type",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "multi-a.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.multi-a.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
{
ID: "record2",
Name: "www.multi-a.com",
Type: records.RecordTypeA,
Content: "192.168.1.2",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "multi-a.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.multi-a.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
{
Name: "www.multi-a.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.2",
},
},
SearchDomainDisabled: true,
},
},
},
{
name: "peer in multiple groups accessing different zones",
accountZones: []*zones.Zone{
{
ID: "zone1",
Domain: "zone1.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
Records: []*records.Record{
{
ID: "record1",
Name: "www.zone1.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
},
},
},
{
ID: "zone2",
Domain: "zone2.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group2"},
Records: []*records.Record{
{
ID: "record2",
Name: "www.zone2.com",
Type: records.RecordTypeA,
Content: "192.168.1.2",
TTL: 300,
},
},
},
},
peerGroups: LookupMap{"group1": struct{}{}, "group2": struct{}{}},
expected: []nbdns.CustomZone{
{
Domain: "zone1.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.zone1.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.1",
},
},
SearchDomainDisabled: true,
},
{
Domain: "zone2.com.",
Records: []nbdns.SimpleRecord{
{
Name: "www.zone2.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.2",
},
},
SearchDomainDisabled: true,
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := filterPeerAppliedZones(ctx, tt.accountZones, tt.peerGroups)
require.Equal(t, len(tt.expected), len(result), "number of custom zones should match")
for i, expectedZone := range tt.expected {
assert.Equal(t, expectedZone.Domain, result[i].Domain, "domain should match")
assert.Equal(t, expectedZone.SearchDomainDisabled, result[i].SearchDomainDisabled, "search domain disabled flag should match")
assert.Equal(t, len(expectedZone.Records), len(result[i].Records), "number of records should match")
for j, expectedRecord := range expectedZone.Records {
assert.Equal(t, expectedRecord.Name, result[i].Records[j].Name, "record name should match")
assert.Equal(t, expectedRecord.Type, result[i].Records[j].Type, "record type should match")
assert.Equal(t, expectedRecord.Class, result[i].Records[j].Class, "record class should match")
assert.Equal(t, expectedRecord.TTL, result[i].Records[j].TTL, "record TTL should match")
assert.Equal(t, expectedRecord.RData, result[i].Records[j].RData, "record RData should match")
}
}
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -25,11 +26,12 @@ func (a *Account) GetPeerNetworkMapExp(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeers map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
a.initNetworkMapBuilder(validatedPeers)
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {

View File

@@ -70,13 +70,13 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -115,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -124,7 +124,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
for range b.N {
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
for _, peerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil)
_ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
@@ -177,7 +177,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -185,7 +185,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
err = builder.OnPeerAddedIncremental(account, newPeerID)
require.NoError(t, err, "error adding peer to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -240,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -250,7 +250,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(account, newPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
@@ -317,7 +317,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -325,7 +325,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
err = builder.OnPeerAddedIncremental(account, newRouterID)
require.NoError(t, err, "error adding router to cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -402,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -412,7 +412,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(account, newRouterID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
@@ -458,7 +458,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -466,7 +466,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
err = builder.OnPeerDeleted(account, deletedPeerID)
require.NoError(t, err, "error deleting peer from cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -537,7 +537,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
@@ -545,7 +545,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
err = builder.OnPeerDeleted(account, deletedRouterID)
require.NoError(t, err, "error deleting routing peer from cache")
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
@@ -597,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -607,7 +607,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerDeleted(account, deletedPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
}
}
})
@@ -944,7 +944,7 @@ func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T
time.Sleep(100 * time.Millisecond)
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil)
normalizeAndSortNetworkMap(networkMap)

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -1033,7 +1034,7 @@ func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account {
}
func (b *NetworkMapBuilder) GetPeerNetworkMap(
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone,
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
@@ -1057,7 +1058,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
return &NetworkMap{Network: account.Network.Copy()}
}
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, accountZones, validatedPeers)
if metrics != nil {
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
@@ -1074,8 +1075,8 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
}
func (b *NetworkMapBuilder) assembleNetworkMap(
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
dnsConfig *nbdns.Config, sshView *PeerSSHView, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{},
) *NetworkMap {
var peersToConnect []*nbpeer.Peer
@@ -1125,13 +1126,26 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
}
finalDNSConfig := *dnsConfig
if finalDNSConfig.ServiceEnable && customZone.Domain != "" {
if finalDNSConfig.ServiceEnable {
var zones []nbdns.CustomZone
records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers)
peerGroupsSlice := b.cache.peerToGroups[peer.ID]
peerGroups := make(LookupMap, len(peerGroupsSlice))
for _, groupID := range peerGroupsSlice {
peerGroups[groupID] = struct{}{}
}
if peersCustomZone.Domain != "" {
records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers)
zones = append(zones, nbdns.CustomZone{
Domain: customZone.Domain,
Domain: peersCustomZone.Domain,
Records: records,
})
}
filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups)
zones = append(zones, filteredAccountZones...)
finalDNSConfig.CustomZones = zones
}

View File

@@ -911,10 +911,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{}
switch {
case allowed:
start := time.Now()
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
log.WithContext(ctx).Tracef("Got %d users from account %s after %s", len(accountUsers), accountID, time.Since(start))
case user != nil && user.AccountID == accountID:
accountUsers = append(accountUsers, user)
default:
@@ -933,23 +935,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0)
filtered := make(map[string]*idp.UserData, len(accountUsers))
log.WithContext(ctx).Tracef("Querying users from IDP for account %s", accountID)
start := time.Now()
integrationKeys := make(map[string]struct{})
for _, user := range accountUsers {
if user.Issued == types.UserIssuedIntegration {
key := user.IntegrationReference.CacheKey(accountID, user.Id)
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = true
continue
}
usersFromIntegration = append(usersFromIntegration, info)
integrationKeys[user.IntegrationReference.CacheKey(accountID)] = struct{}{}
continue
}
if !user.IsServiceUser {
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
}
}
for key := range integrationKeys {
usersData, err := am.externalCacheManager.GetUsers(am.ctx, key)
if err != nil {
log.WithContext(ctx).Debugf("GetUsers from ExternalCache for key: %s, error: %s", key, err)
continue
}
for _, ud := range usersData {
filtered[ud.ID] = ud
}
}
for _, ud := range filtered {
usersFromIntegration = append(usersFromIntegration, ud)
}
log.WithContext(ctx).Tracef("Got user info from external cache after %s", time.Since(start))
start = time.Now()
queriedUsers, err = am.lookupCache(ctx, users, accountID)
log.WithContext(ctx).Tracef("Got user info from cache for %d users after %s", len(queriedUsers), time.Since(start))
if err != nil {
return nil, err
}

View File

@@ -1086,8 +1086,12 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
assert.NoError(t, err)
cacheManager := am.GetExternalCacheManager()
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}, time.Minute)
tud := &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}
cacheKeyUser := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKeyUser, tud, time.Minute)
assert.NoError(t, err)
cacheKeyAccount := externalUser.IntegrationReference.CacheKey(mockAccountID)
err = cacheManager.SetUsers(context.Background(), cacheKeyAccount, []*idp.UserData{tud}, time.Minute)
assert.NoError(t, err)
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)

View File

@@ -1,5 +1,9 @@
package util
import "regexp"
var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`)
// Difference returns the elements in `a` that aren't in `b`.
func Difference(a, b []string) []string {
mb := make(map[string]struct{}, len(b))
@@ -50,3 +54,10 @@ func contains[T comparableObject[T]](slice []T, element T) bool {
}
return false
}
func IsValidDomain(domain string) bool {
if domain == "" {
return false
}
return domainRegex.MatchString(domain)
}

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
)
@@ -43,6 +45,10 @@ type Config struct {
LogLevel string
LogFile string
HealthcheckListenAddress string
// STUN server configuration
EnableSTUN bool
STUNPorts []int
STUNLogLevel string
}
func (c Config) Validate() error {
@@ -52,6 +58,25 @@ func (c Config) Validate() error {
if c.AuthSecret == "" {
return fmt.Errorf("auth secret is required")
}
// Validate STUN configuration
if c.EnableSTUN {
if len(c.STUNPorts) == 0 {
return fmt.Errorf("--stun-ports is required when --enable-stun is set")
}
seen := make(map[int]bool)
for _, port := range c.STUNPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid STUN port %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d", port)
}
seen[port] = true
}
}
return nil
}
@@ -91,6 +116,9 @@ func init() {
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
rootCmd.PersistentFlags().BoolVar(&cobraConfig.EnableSTUN, "enable-stun", false, "enable embedded STUN server")
rootCmd.PersistentFlags().IntSliceVar(&cobraConfig.STUNPorts, "stun-ports", []int{3478}, "ports for the embedded STUN server (can be specified multiple times or comma-separated)")
rootCmd.PersistentFlags().StringVar(&cobraConfig.STUNLogLevel, "stun-log-level", "info", "log level for STUN server (panic, fatal, error, warn, info, debug, trace)")
setFlagsFromEnvVars(rootCmd)
}
@@ -119,21 +147,14 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err)
}
// Resource creation phase (fail fast before starting any goroutines)
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
if err != nil {
log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err)
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start metrics server: %v", err)
}
}()
srvListenerCfg := server.ListenerConfig{
Address: cobraConfig.ListenAddress,
}
@@ -145,6 +166,12 @@ func execute(cmd *cobra.Command, args []string) error {
}
srvListenerCfg.TLSConfig = tlsConfig
// Create STUN listeners early to fail fast
stunListeners, err := createSTUNListeners()
if err != nil {
return err
}
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
@@ -155,60 +182,145 @@ func execute(cmd *cobra.Command, args []string) error {
TLSSupport: tlsSupport,
}
srv, err := server.NewServer(cfg)
srv, err := createRelayServer(cfg)
if err != nil {
log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err)
cleanupSTUNListeners(stunListeners)
return err
}
hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := createHealthCheck(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return err
}
var stunServer *stun.Server
if len(stunListeners) > 0 {
stunServer = stun.NewServer(stunListeners, cobraConfig.STUNLogLevel)
}
// Start all servers (only after all resources are successfully created)
startServers(&wg, metricsServer, srv, srvListenerCfg, httpHealthcheck, stunServer)
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = shutdownServers(ctx, metricsServer, srv, httpHealthcheck, stunServer)
wg.Wait()
return err
}
func startServers(wg *sync.WaitGroup, metricsServer *metrics.Metrics, srv *server.Server, srvListenerCfg server.ListenerConfig, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) {
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
instanceURL := srv.InstanceURL()
log.Infof("server will be available on: %s", instanceURL.String())
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Listen(srvListenerCfg); err != nil {
log.Fatalf("failed to bind server: %s", err)
log.Fatalf("failed to bind relay server: %s", err)
}
}()
hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return fmt.Errorf("failed to create healthcheck server: %v", err)
}
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start healthcheck server: %v", err)
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()
// it will block until exit signal
waitForExitSignal()
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
func shutdownServers(ctx context.Context, metricsServer *metrics.Metrics, srv *server.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) error {
var errs error
var shutDownErrors error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}
if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
}
if err := srv.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
}
log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
}
wg.Wait()
return shutDownErrors
return errs
}
func createHealthCheck(hCfg healthcheck.Config) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return nil, fmt.Errorf("failed to create healthcheck server: %v", err)
}
return httpHealthcheck, nil
}
func createRelayServer(cfg server.Config) (*server.Server, error) {
srv, err := server.NewServer(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create relay server: %v", err)
}
return srv, nil
}
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}
func createSTUNListeners() ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cobraConfig.EnableSTUN {
for _, port := range cobraConfig.STUNPorts {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
// Close already opened listeners on failure
cleanupSTUNListeners(stunListeners)
log.Debugf("failed to create STUN listener on port %d: %v", port, err)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %v", port, err)
}
stunListeners = append(stunListeners, listener)
}
}
return stunListeners, nil
}
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {

View File

@@ -8,6 +8,10 @@ pid="$(pgrep -x -f /usr/bin/netbird-ui || true)"
if [ -n "${pid}" ]
then
uid="$(cat /proc/"${pid}"/loginuid)"
# loginuid can be 4294967295 (-1) if not set, fall back to process uid
if [ "${uid}" = "4294967295" ] || [ "${uid}" = "-1" ]; then
uid="$(stat -c '%u' /proc/"${pid}")"
fi
username="$(id -nu "${uid}")"
# Only re-run if it was already running
pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1

Some files were not shown because too many files have changed in this diff Show More