Compare commits

...

8 Commits

Author SHA1 Message Date
hakansa
70db8751d7 [client] Add --disable-update-settings flag to the service (#4335)
[client] Add --disable-update-settings flag to the service (#4335)
2025-08-13 21:05:12 +03:00
Zoltan Papp
86a00ab4af Fix Go tarball version in FreeBSD build configuration (#4339) 2025-08-13 13:52:11 +02:00
Zoltan Papp
3d4b502126 [server] Add health check HTTP endpoint for Relay server (#4297)
The health check endpoint listens on a dedicated HTTP server.
By default, it is available at 0.0.0.0:9000/health. This can be configured using the --health-listen-address flag.

The results are cached for 3 seconds to avoid excessive calls.

The health check performs the following:

Checks the number of active listeners.
Validates each listener via WebSocket and QUIC dials, including TLS certificate verification.
2025-08-13 10:40:04 +02:00
Bethuel Mmbaga
a4e8647aef [management] Enable flow groups (#4230)
Adds the ability to limit traffic events logging to specific peer groups
2025-08-13 00:00:40 +03:00
Viktor Liu
160b811e21 [client] Distinguish between NXDOMAIN and NODATA in the dns forwarder (#4321) 2025-08-12 15:59:42 +02:00
Viktor Liu
5e607cf4e9 [client] Skip dns upstream servers pointing to our dns server IP to prevent loops (#4330) 2025-08-12 15:41:23 +02:00
Viktor Liu
0fdb944058 [client] Create NRPT rules separately per domain (#4329) 2025-08-12 15:40:37 +02:00
Zoltan Papp
ccbabd9e2a Add pprof support for Relay server (#4203) 2025-08-12 12:24:24 +02:00
53 changed files with 1187 additions and 158 deletions

View File

@@ -25,8 +25,7 @@ jobs:
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -73,6 +73,7 @@ var (
dnsRouteInterval time.Duration
lazyConnEnabled bool
profilesDisabled bool
updateSettingsDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",

View File

@@ -42,7 +42,8 @@ func init() {
}
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile.")
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` +

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
}
}
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled)
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}

View File

@@ -49,6 +49,14 @@ func buildServiceArguments() []string {
args = append(args, "--log-file", logFile)
}
if profilesDisabled {
args = append(args, "--disable-profiles")
}
if updateSettingsDisabled {
args = append(args, "--disable-update-settings")
}
return args
}

View File

@@ -11,6 +11,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -26,8 +27,8 @@ import (
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
mgmt "github.com/netbirdio/netbird/management/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server"
)
@@ -97,6 +98,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
@@ -108,7 +110,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
@@ -134,7 +136,7 @@ func startClientDaemon(
s := grpc.NewServer()
server := client.New(ctx,
"", "", false)
"", "", false, false)
if err := server.Start(); err != nil {
t.Fatal(err)
}

View File

@@ -176,4 +176,3 @@ nameserver 192.168.0.1
t.Errorf("unexpected resolv.conf content: %v", cfg)
}
}

View File

@@ -64,9 +64,10 @@ const (
)
type registryConfigurator struct {
guid string
routingAll bool
gpo bool
guid string
routingAll bool
gpo bool
nrptEntryCount int
}
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -177,7 +178,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
}
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
@@ -193,13 +198,24 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
}
if len(matchDomains) != 0 {
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil {
return fmt.Errorf("add dns match policy: %w", err)
}
r.nrptEntryCount = count
} else {
if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err)
}
r.nrptEntryCount = 0
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
if err := r.updateSearchDomains(searchDomains); err != nil {
@@ -220,28 +236,34 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
return nil
}
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error {
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
if r.gpo {
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil {
return fmt.Errorf("configure GPO DNS policy: %w", err)
for i, domain := range domains {
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
if r.gpo {
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
}
singleDomain := []string{domain}
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
}
log.Debugf("added NRPT entry for domain: %s", domain)
}
if r.gpo {
if err := refreshGroupPolicy(); err != nil {
log.Warnf("failed to refresh group policy: %v", err)
}
} else {
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
return fmt.Errorf("configure local DNS policy: %w", err)
}
}
log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
return nil
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return len(domains), nil
}
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
return fmt.Errorf("remove existing dns policy: %w", err)
@@ -374,12 +396,25 @@ func (r *registryConfigurator) restoreHostDNS() error {
func (r *registryConfigurator) removeDNSMatchPolicies() error {
var merr *multierror.Error
// Try to remove the base entries (for backward compatibility)
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err))
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err))
for i := 0; i < r.nrptEntryCount; i++ {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
}
}
if err := refreshGroupPolicy(); err != nil {

View File

@@ -695,6 +695,12 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue
}
if ns.IP == s.service.RuntimeIP() {
log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP)
continue
}
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
}

View File

@@ -2056,3 +2056,124 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}
func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
dnsServerIP := service.RuntimeIP()
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
tests := []struct {
name string
nsGroups []*nbdns.NameServerGroup
expectedHandlers int
expectedServers []netip.Addr
shouldFilterOwnIP bool
}{
{
name: "FilterOwnDNSServerIP",
nsGroups: []*nbdns.NameServerGroup{
{
Primary: true,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: true,
},
{
name: "AllServersFiltered",
nsGroups: []*nbdns.NameServerGroup{
{
Primary: false,
NameServers: []nbdns.NameServer{
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
},
expectedHandlers: 0,
expectedServers: []netip.Addr{},
shouldFilterOwnIP: true,
},
{
name: "MixedServersWithOwnIP",
nsGroups: []*nbdns.NameServerGroup{
{
Primary: false,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, // duplicate
},
Domains: []string{"test.com"},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: true,
},
{
name: "NoOwnIPInList",
nsGroups: []*nbdns.NameServerGroup{
{
Primary: true,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
muxUpdates, err := server.buildUpstreamHandlerUpdate(tt.nsGroups)
assert.NoError(t, err)
assert.Len(t, muxUpdates, tt.expectedHandlers)
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
if upstream.Addr() == expected {
found = true
break
}
}
assert.True(t, found, "Expected server %s not found", expected)
}
}
})
}
}

View File

@@ -5,8 +5,9 @@ import (
)
type ShutdownState struct {
Guid string
GPO bool
Guid string
GPO bool
NRPTEntryCount int
}
func (s *ShutdownState) Name() string {
@@ -15,8 +16,9 @@ func (s *ShutdownState) Name() string {
func (s *ShutdownState) Cleanup() error {
manager := &registryConfigurator{
guid: s.Guid,
gpo: s.GPO,
guid: s.Guid,
gpo: s.GPO,
nrptEntryCount: s.NRPTEntryCount,
}
if err := manager.restoreUncleanShutdownDNS(); err != nil {

View File

@@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil {
f.handleDNSError(w, query, resp, domain, err)
f.handleDNSError(ctx, w, question, resp, domain, err)
return nil
}
@@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
}
}
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
// Pass through NXDOMAIN
resp.Rcode = dns.RcodeNameError
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}

View File

@@ -3,6 +3,7 @@ package dnsfwd
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"testing"
@@ -16,8 +17,8 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
func Test_getMatchingEntries(t *testing.T) {
@@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
assert.Len(t, matches, 3, "Should match 3 patterns")
}
// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes
// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type)
func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}}
forwarder.UpdateDomains(entries)
tests := []struct {
name string
queryType uint16
setupMocks func()
expectedCode int
expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case)
description string
}{
{
name: "domain exists but no AAAA records (NODATA)",
queryType: dns.TypeAAAA,
setupMocks: func() {
// First query for AAAA returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for A records succeeds (domain exists)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: true,
description: "Should return NOERROR when domain exists but has no records of requested type",
},
{
name: "domain exists but no A records (NODATA)",
queryType: dns.TypeA,
setupMocks: func() {
// First query for A returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for AAAA records succeeds (domain exists)
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: true,
description: "Should return NOERROR when domain exists but has no A records",
},
{
name: "domain doesn't exist (NXDOMAIN)",
queryType: dns.TypeA,
setupMocks: func() {
// First query for A returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for AAAA also returns not found (domain doesn't exist)
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
},
expectedCode: dns.RcodeNameError,
expectNoAnswer: true,
description: "Should return NXDOMAIN when domain doesn't exist at all",
},
{
name: "domain exists with records (normal success)",
queryType: dns.TypeA,
setupMocks: func() {
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
// Expect firewall update for successful resolution
expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32)
mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: false,
description: "Should return NOERROR with answer when records exist",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset mock expectations
mockResolver.ExpectedCalls = nil
mockResolver.Calls = nil
mockFirewall.ExpectedCalls = nil
mockFirewall.Calls = nil
tt.setupMocks()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
})
}
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})

View File

@@ -27,6 +27,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
@@ -1564,13 +1565,14 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
AnyTimes()
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err

View File

@@ -4430,6 +4430,94 @@ func (*LogoutResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{66}
}
type GetFeaturesRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GetFeaturesRequest) Reset() {
*x = GetFeaturesRequest{}
mi := &file_daemon_proto_msgTypes[67]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GetFeaturesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetFeaturesRequest) ProtoMessage() {}
func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[67]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead.
func (*GetFeaturesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{67}
}
type GetFeaturesResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GetFeaturesResponse) Reset() {
*x = GetFeaturesResponse{}
mi := &file_daemon_proto_msgTypes[68]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GetFeaturesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetFeaturesResponse) ProtoMessage() {}
func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[68]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead.
func (*GetFeaturesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{68}
}
func (x *GetFeaturesResponse) GetDisableProfiles() bool {
if x != nil {
return x.DisableProfiles
}
return false
}
func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool {
if x != nil {
return x.DisableUpdateSettings
}
return false
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -4440,7 +4528,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[68]
mi := &file_daemon_proto_msgTypes[70]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4452,7 +4540,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[68]
mi := &file_daemon_proto_msgTypes[70]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4872,7 +4960,11 @@ const file_daemon_proto_rawDesc = "" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_username\"\x10\n" +
"\x0eLogoutResponse*b\n" +
"\x0eLogoutResponse\"\x14\n" +
"\x12GetFeaturesRequest\"x\n" +
"\x13GetFeaturesResponse\x12)\n" +
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -4881,7 +4973,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xc5\x0f\n" +
"\x05TRACE\x10\a2\x8f\x10\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -4912,7 +5004,8 @@ const file_daemon_proto_rawDesc = "" +
"\rRemoveProfile\x12\x1c.daemon.RemoveProfileRequest\x1a\x1d.daemon.RemoveProfileResponse\"\x00\x12K\n" +
"\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" +
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00B\bZ\x06/protob\x06proto3"
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" +
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -4927,7 +5020,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 70)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 72)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
@@ -4999,18 +5092,20 @@ var file_daemon_proto_goTypes = []any{
(*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse
(*LogoutRequest)(nil), // 68: daemon.LogoutRequest
(*LogoutResponse)(nil), // 69: daemon.LogoutResponse
nil, // 70: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 71: daemon.PortInfo.Range
nil, // 72: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 73: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 74: google.protobuf.Timestamp
(*GetFeaturesRequest)(nil), // 70: daemon.GetFeaturesRequest
(*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse
nil, // 72: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 73: daemon.PortInfo.Range
nil, // 74: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 75: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 76: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
73, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
75, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
74, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
74, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
73, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
76, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
76, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
75, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -5019,8 +5114,8 @@ var file_daemon_proto_depIdxs = []int32{
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
70, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
71, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
72, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
73, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -5031,10 +5126,10 @@ var file_daemon_proto_depIdxs = []int32{
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
74, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
72, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
76, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
74, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
73, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
75, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -5064,35 +5159,37 @@ var file_daemon_proto_depIdxs = []int32{
63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
5, // 58: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
7, // 59: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
9, // 60: daemon.DaemonService.Up:output_type -> daemon.UpResponse
11, // 61: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
13, // 62: daemon.DaemonService.Down:output_type -> daemon.DownResponse
15, // 63: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
24, // 64: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 65: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
26, // 66: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 67: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
33, // 68: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
35, // 69: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
37, // 70: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 71: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
42, // 72: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
44, // 73: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
46, // 74: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
50, // 75: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
52, // 76: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
54, // 77: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
56, // 78: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
58, // 79: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
60, // 80: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
62, // 81: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
64, // 82: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
67, // 83: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
69, // 84: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
58, // [58:85] is the sub-list for method output_type
31, // [31:58] is the sub-list for method input_type
70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
5, // 59: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
7, // 60: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
9, // 61: daemon.DaemonService.Up:output_type -> daemon.UpResponse
11, // 62: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
13, // 63: daemon.DaemonService.Down:output_type -> daemon.DownResponse
15, // 64: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
24, // 65: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 66: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
26, // 67: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 68: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
33, // 69: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
35, // 70: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
37, // 71: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 72: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
42, // 73: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
44, // 74: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
46, // 75: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
50, // 76: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
52, // 77: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
54, // 78: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
56, // 79: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
58, // 80: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
60, // 81: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
62, // 82: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
64, // 83: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
67, // 84: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
69, // 85: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
71, // 86: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
59, // [59:87] is the sub-list for method output_type
31, // [31:59] is the sub-list for method input_type
31, // [31:31] is the sub-list for extension type_name
31, // [31:31] is the sub-list for extension extendee
0, // [0:31] is the sub-list for field type_name
@@ -5120,7 +5217,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 3,
NumMessages: 70,
NumMessages: 72,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -82,6 +82,8 @@ service DaemonService {
// Logout disconnects from the network and deletes the peer from the management server
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {}
}
@@ -624,4 +626,11 @@ message LogoutRequest {
optional string username = 2;
}
message LogoutResponse {}
message LogoutResponse {}
message GetFeaturesRequest{}
message GetFeaturesResponse{
bool disable_profiles = 1;
bool disable_update_settings = 2;
}

View File

@@ -63,6 +63,7 @@ type DaemonServiceClient interface {
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
// Logout disconnects from the network and deletes the peer from the management server
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
}
type daemonServiceClient struct {
@@ -339,6 +340,15 @@ func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opt
return out, nil
}
func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) {
out := new(GetFeaturesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetFeatures", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -388,6 +398,7 @@ type DaemonServiceServer interface {
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
// Logout disconnects from the network and deletes the peer from the management server
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -476,6 +487,9 @@ func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetAc
func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
}
func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -978,6 +992,24 @@ func _DaemonService_Logout_Handler(srv interface{}, ctx context.Context, dec fun
return interceptor(ctx, in, info, handler)
}
func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetFeaturesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).GetFeatures(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetFeatures",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetFeatures(ctx, req.(*GetFeaturesRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1089,6 +1121,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "Logout",
Handler: _DaemonService_Logout_Handler,
},
{
MethodName: "GetFeatures",
Handler: _DaemonService_GetFeatures_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -46,8 +46,9 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
)
var ErrServiceNotUp = errors.New("service is not up")
@@ -74,8 +75,9 @@ type Server struct {
persistSyncResponse bool
isSessionActive atomic.Bool
profileManager *profilemanager.ServiceManager
profilesDisabled bool
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
}
type oauthAuthFlow struct {
@@ -86,14 +88,15 @@ type oauthAuthFlow struct {
}
// New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool) *Server {
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
return &Server{
rootCtx: ctx,
logFile: logFile,
persistSyncResponse: true,
statusRecorder: peer.NewRecorder(""),
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
rootCtx: ctx,
logFile: logFile,
persistSyncResponse: true,
statusRecorder: peer.NewRecorder(""),
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
}
}
@@ -322,8 +325,8 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
if s.checkUpdateSettingsDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
}
profState := profilemanager.ActiveProfileState{
@@ -1330,10 +1333,31 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
log.Warn("Profiles are disabled via NB_DISABLE_PROFILES environment variable")
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}

View File

@@ -14,6 +14,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -24,7 +25,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -33,6 +33,7 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
@@ -94,7 +95,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false)
s := New(ctx, "debug", "", false, false)
s.config = config
@@ -151,7 +152,7 @@ func TestServer_Up(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "console", "", false)
s := New(ctx, "console", "", false, false)
err = s.Start()
require.NoError(t, err)
@@ -227,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "console", "", false)
s := New(ctx, "console", "", false, false)
err = s.Start()
require.NoError(t, err)
@@ -302,13 +303,14 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
return nil, "", err
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil {
return nil, "", err

View File

@@ -392,6 +392,16 @@ func (s *serviceClient) updateIcon() {
}
func (s *serviceClient) showSettingsUI() {
// Check if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Update settings are disabled by daemon")
return
}
// add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings")
s.wSettings.SetOnClosed(s.cancel)
@@ -447,6 +457,17 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
},
SubmitText: "Save",
OnSubmit: func() {
// Check if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Configuration updates are disabled by daemon")
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
return
}
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
// validate preSharedKey if it added
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
@@ -836,6 +857,20 @@ func (s *serviceClient) onTrayReady() {
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
s.loadSettings()
// Disable settings menu if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else {
if features != nil && features.DisableUpdateSettings {
s.setSettingsEnabled(false)
}
if features != nil && features.DisableProfiles {
s.mProfile.setEnabled(false)
}
}
s.exitNodeMu.Lock()
s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr)
s.mExitNode.Disable()
@@ -876,6 +911,10 @@ func (s *serviceClient) onTrayReady() {
if err != nil {
log.Errorf("error while updating status: %v", err)
}
// Check features periodically to handle daemon restarts
s.checkAndUpdateFeatures()
time.Sleep(2 * time.Second)
}
}()
@@ -948,6 +987,59 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
return s.conn, nil
}
// setSettingsEnabled enables or disables the settings menu based on the provided state
func (s *serviceClient) setSettingsEnabled(enabled bool) {
if s.mSettings != nil {
if enabled {
s.mSettings.Enable()
s.mSettings.SetTooltip(settingsMenuDescr)
} else {
s.mSettings.Hide()
s.mSettings.SetTooltip("Settings are disabled by daemon")
}
}
}
// checkAndUpdateFeatures checks the current features and updates the UI accordingly
func (s *serviceClient) checkAndUpdateFeatures() {
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
return
}
// Update settings menu based on current features
if features != nil && features.DisableUpdateSettings {
s.setSettingsEnabled(false)
} else {
s.setSettingsEnabled(true)
}
// Update profile menu based on current features
if s.mProfile != nil {
if features != nil && features.DisableProfiles {
s.mProfile.setEnabled(false)
} else {
s.mProfile.setEnabled(true)
}
}
}
// getFeatures from the daemon to determine which features are enabled/disabled.
func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return nil, fmt.Errorf("get client for features: %w", err)
}
features, err := conn.GetFeatures(s.ctx, &proto.GetFeaturesRequest{})
if err != nil {
return nil, fmt.Errorf("get features from daemon: %w", err)
}
return features, nil
}
// getSrvConfig from the service to show it in the settings window.
func (s *serviceClient) getSrvConfig() {
s.managementURL = profilemanager.DefaultManagementURL

View File

@@ -654,6 +654,19 @@ func (p *profileMenu) clear(profiles []Profile) {
}
}
// setEnabled enables or disables the profile menu based on the provided state
func (p *profileMenu) setEnabled(enabled bool) {
if p.profileMenuItem != nil {
if enabled {
p.profileMenuItem.Enable()
p.profileMenuItem.SetTooltip("")
} else {
p.profileMenuItem.Hide()
p.profileMenuItem.SetTooltip("Profiles are disabled by daemon")
}
}
}
func (p *profileMenu) updateMenu() {
// check every second
ticker := time.NewTicker(time.Second)
@@ -662,7 +675,6 @@ func (p *profileMenu) updateMenu() {
for {
select {
case <-ticker.C:
// get profilesList
profiles, err := p.getProfiles()
if err != nil {

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f h1:YmqNWdRbeVn1lSpkLzIiFHX2cndRuaVYyynx2ibrOtg=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e h1:S85laGfx1UP+nmRF9smP6/TY965kLWz41PbBK1TX8g0=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e/go.mod h1:Jjve0+eUjOLKL3PJtAhjfM2iJ0SxWio5elHqlV1ymP8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -34,6 +34,7 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types"
@@ -45,7 +46,6 @@ import (
"github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics"
@@ -220,7 +220,8 @@ var (
return fmt.Errorf("build default manager: %v", err)
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager)
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager, groupsManager)
trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
@@ -277,7 +278,6 @@ var (
config.GetAuthAudiences(),
config.HttpConfig.IdpSignKeyRefreshEnabled)
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)

View File

@@ -21,6 +21,7 @@ type Manager interface {
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error)
RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error)
GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error)
}
type managerImpl struct {
@@ -142,6 +143,10 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
}
func (m *managerImpl) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
return m.store.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
}
func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum {
groupsInfoMap := make(map[string][]api.GroupMinimum, idCount)
groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation)
@@ -202,6 +207,10 @@ func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context,
}, nil
}
func (m *mockManager) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
return []string{}, nil
}
func NewManagerMock() Manager {
return &mockManager{}
}

View File

@@ -662,7 +662,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
}
}
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings) *proto.SyncResponse {
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{
@@ -674,7 +674,7 @@ func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
response.NetworkMap.PeerConfig = response.PeerConfig
@@ -750,7 +750,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request")
}
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra)
peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
if err != nil {
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
}
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {

View File

@@ -199,6 +199,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
settings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
}
}
@@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
apiSettings.Extra = &api.AccountExtraSettings{
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
}
}

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -446,6 +447,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
@@ -455,7 +457,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
return nil, nil, "", cleanup, err
}
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})

View File

@@ -23,6 +23,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -216,7 +217,8 @@ func startServer(
t.Fatalf("failed creating an account manager: %v", err)
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(
context.Background(),
config,

View File

@@ -1275,8 +1275,9 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
}
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting)
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups))
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
@@ -1386,7 +1387,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings)
peerGroups := account.GetPeerGroups(peerId)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups))
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}

View File

@@ -1164,7 +1164,7 @@ func TestToSyncResponse(t *testing.T) {
}
dnsCache := &DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil)
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{})
assert.NotNil(t, response)
// assert peer config

View File

@@ -68,6 +68,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
// Once we migrate the peer approval to settings manager this merging is obsolete
if settings.Extra != nil {
settings.Extra.FlowEnabled = extraSettings.FlowEnabled
settings.Extra.FlowGroups = extraSettings.FlowGroups
settings.Extra.FlowPacketCounterEnabled = extraSettings.FlowPacketCounterEnabled
settings.Extra.FlowENCollectionEnabled = extraSettings.FlowENCollectionEnabled
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
@@ -93,6 +94,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
}
settings.Extra.FlowEnabled = extraSettings.FlowEnabled
settings.Extra.FlowGroups = extraSettings.FlowGroups
return settings.Extra, nil
}

View File

@@ -11,13 +11,13 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/proto"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
)
const defaultDuration = 12 * time.Hour
@@ -39,13 +39,14 @@ type TimeBasedAuthSecretsManager struct {
relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager
settingsManager settings.Manager
groupsManager groups.Manager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
}
type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager,
turnCfg: turnCfg,
@@ -53,6 +54,7 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
turnCancelMap: make(map[string]chan struct{}),
relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager,
groupsManager: groupsManager,
}
if turnCfg != nil {
@@ -258,6 +260,11 @@ func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, p
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
}
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings)
peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
}
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
update.NetbirdConfig = extendedConfig
}

View File

@@ -13,9 +13,10 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
@@ -40,13 +41,14 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
}, rc, settingsMockManager, groupsManager)
turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err)
@@ -91,13 +93,14 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
}, rc, settingsMockManager, groupsManager)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -193,13 +196,14 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl,
Secret: secret,
Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true,
}, rc, settingsMockManager)
}, rc, settingsMockManager, groupsManager)
tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok {

View File

@@ -2,6 +2,7 @@ package types
import (
"net/netip"
"slices"
"time"
)
@@ -87,21 +88,21 @@ type ExtraSettings struct {
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
IntegratedValidatorGroups []string `gorm:"serializer:json"`
FlowEnabled bool `gorm:"-"`
FlowPacketCounterEnabled bool `gorm:"-"`
FlowENCollectionEnabled bool `gorm:"-"`
FlowDnsCollectionEnabled bool `gorm:"-"`
FlowEnabled bool `gorm:"-"`
FlowGroups []string `gorm:"-"`
FlowPacketCounterEnabled bool `gorm:"-"`
FlowENCollectionEnabled bool `gorm:"-"`
FlowDnsCollectionEnabled bool `gorm:"-"`
}
// Copy copies the ExtraSettings struct
func (e *ExtraSettings) Copy() *ExtraSettings {
var cpGroup []string
return &ExtraSettings{
PeerApprovalEnabled: e.PeerApprovalEnabled,
IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...),
IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups),
IntegratedValidator: e.IntegratedValidator,
FlowEnabled: e.FlowEnabled,
FlowGroups: slices.Clone(e.FlowGroups),
FlowPacketCounterEnabled: e.FlowPacketCounterEnabled,
FlowENCollectionEnabled: e.FlowENCollectionEnabled,
FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled,

33
relay/cmd/pprof.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -9,6 +9,7 @@ import (
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
@@ -17,8 +18,9 @@ import (
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/util"
)
@@ -34,12 +36,13 @@ type Config struct {
LetsencryptDomains []string
// in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or
// in the AWS credentials file
LetsencryptAWSRoute53 bool
TlsCertFile string
TlsKeyFile string
AuthSecret string
LogLevel string
LogFile string
LetsencryptAWSRoute53 bool
TlsCertFile string
TlsKeyFile string
AuthSecret string
LogLevel string
LogFile string
HealthcheckListenAddress string
}
func (c Config) Validate() error {
@@ -87,6 +90,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
setFlagsFromEnvVars(rootCmd)
}
@@ -102,6 +106,7 @@ func waitForExitSignal() {
}
func execute(cmd *cobra.Command, args []string) error {
wg := sync.WaitGroup{}
err := cobraConfig.Validate()
if err != nil {
log.Debugf("invalid config: %s", err)
@@ -120,7 +125,9 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("setup metrics: %v", err)
}
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start metrics server: %v", err)
@@ -154,12 +161,31 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to create relay server: %v", err)
}
log.Infof("server will be available on: %s", srv.InstanceURL())
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Listen(srvListenerCfg); err != nil {
log.Fatalf("failed to bind server: %s", err)
}
}()
hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return fmt.Errorf("failed to create healthcheck server: %v", err)
}
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start healthcheck server: %v", err)
}
}()
// it will block until exit signal
waitForExitSignal()
@@ -167,6 +193,10 @@ func execute(cmd *cobra.Command, args []string) error {
defer cancel()
var shutDownErrors error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
}
if err := srv.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
}
@@ -175,6 +205,8 @@ func execute(cmd *cobra.Command, args []string) error {
if err := metricsServer.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
}
wg.Wait()
return shutDownErrors
}

View File

@@ -0,0 +1,195 @@
package healthcheck
import (
"context"
"encoding/json"
"errors"
"net"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws"
)
const (
statusHealthy = "healthy"
statusUnhealthy = "unhealthy"
path = "/health"
cacheTTL = 3 * time.Second // Cache TTL for health status
)
type ServiceChecker interface {
ListenerProtocols() []protocol.Protocol
ListenAddress() string
}
type HealthStatus struct {
Status string `json:"status"`
Timestamp time.Time `json:"timestamp"`
Listeners []protocol.Protocol `json:"listeners"`
CertificateValid bool `json:"certificate_valid"`
}
type Config struct {
ListenAddress string
ServiceChecker ServiceChecker
}
type Server struct {
config Config
httpServer *http.Server
cacheMu sync.Mutex
cacheStatus *HealthStatus
}
func NewServer(config Config) (*Server, error) {
mux := http.NewServeMux()
if config.ServiceChecker == nil {
return nil, errors.New("service checker is required")
}
server := &Server{
config: config,
httpServer: &http.Server{
Addr: config.ListenAddress,
Handler: mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 15 * time.Second,
},
}
mux.HandleFunc(path, server.handleHealthcheck)
return server, nil
}
func (s *Server) ListenAndServe() error {
log.Infof("starting healthcheck server on: http://%s%s", dialAddress(s.config.ListenAddress), path)
return s.httpServer.ListenAndServe()
}
// Shutdown gracefully shuts down the healthcheck server
func (s *Server) Shutdown(ctx context.Context) error {
log.Info("Shutting down healthcheck server")
return s.httpServer.Shutdown(ctx)
}
func (s *Server) handleHealthcheck(w http.ResponseWriter, _ *http.Request) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var (
status *HealthStatus
ok bool
)
// Cache check
s.cacheMu.Lock()
status = s.cacheStatus
s.cacheMu.Unlock()
if status != nil && time.Since(status.Timestamp) <= cacheTTL {
ok = status.Status == statusHealthy
} else {
status, ok = s.getHealthStatus(ctx)
// Update cache
s.cacheMu.Lock()
s.cacheStatus = status
s.cacheMu.Unlock()
}
w.Header().Set("Content-Type", "application/json")
if ok {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusServiceUnavailable)
}
encoder := json.NewEncoder(w)
if err := encoder.Encode(status); err != nil {
log.Errorf("Failed to encode healthcheck response: %v", err)
}
}
func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) {
healthy := true
status := &HealthStatus{
Timestamp: time.Now(),
Status: statusHealthy,
CertificateValid: true,
}
listeners, ok := s.validateListeners()
if !ok {
status.Status = statusUnhealthy
healthy = false
}
status.Listeners = listeners
if ok := s.validateCertificate(ctx); !ok {
status.Status = statusUnhealthy
status.CertificateValid = false
healthy = false
}
return status, healthy
}
func (s *Server) validateListeners() ([]protocol.Protocol, bool) {
listeners := s.config.ServiceChecker.ListenerProtocols()
if len(listeners) == 0 {
return nil, false
}
return listeners, true
}
func (s *Server) validateCertificate(ctx context.Context) bool {
listenAddress := s.config.ServiceChecker.ListenAddress()
if listenAddress == "" {
log.Warn("listen address is empty")
return false
}
dAddr := dialAddress(listenAddress)
for _, proto := range s.config.ServiceChecker.ListenerProtocols() {
switch proto {
case ws.Proto:
if err := dialWS(ctx, dAddr); err != nil {
log.Errorf("failed to dial WebSocket listener: %v", err)
return false
}
case quic.Proto:
if err := dialQUIC(ctx, dAddr); err != nil {
log.Errorf("failed to dial QUIC listener: %v", err)
return false
}
default:
log.Warnf("unknown protocol for healthcheck: %s", proto)
return false
}
}
return true
}
func dialAddress(listenAddress string) string {
host, port, err := net.SplitHostPort(listenAddress)
if err != nil {
return listenAddress // fallback, might be invalid for dialing
}
if host == "" || host == "::" || host == "0.0.0.0" {
host = "0.0.0.0"
}
return net.JoinHostPort(host, port)
}

31
relay/healthcheck/quic.go Normal file
View File

@@ -0,0 +1,31 @@
package healthcheck
import (
"context"
"crypto/tls"
"fmt"
"time"
"github.com/quic-go/quic-go"
tlsnb "github.com/netbirdio/netbird/shared/relay/tls"
)
func dialQUIC(ctx context.Context, address string) error {
tlsConfig := &tls.Config{
InsecureSkipVerify: false, // Keep certificate validation enabled
NextProtos: []string{tlsnb.NBalpn},
}
conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{
MaxIdleTimeout: 30 * time.Second,
KeepAlivePeriod: 10 * time.Second,
EnableDatagrams: true,
})
if err != nil {
return fmt.Errorf("failed to connect to QUIC server: %w", err)
}
_ = conn.CloseWithError(0, "availability check complete")
return nil
}

28
relay/healthcheck/ws.go Normal file
View File

@@ -0,0 +1,28 @@
package healthcheck
import (
"context"
"fmt"
"github.com/coder/websocket"
"github.com/netbirdio/netbird/shared/relay"
)
func dialWS(ctx context.Context, address string) error {
url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath)
conn, resp, err := websocket.Dial(ctx, url, nil)
if resp != nil {
defer func() {
_ = resp.Body.Close()
}()
}
if err != nil {
return fmt.Errorf("failed to connect to websocket: %w", err)
}
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
return nil
}

View File

@@ -0,0 +1,3 @@
package protocol
type Protocol string

View File

@@ -3,9 +3,12 @@ package listener
import (
"context"
"net"
"github.com/netbirdio/netbird/relay/protocol"
)
type Listener interface {
Listen(func(conn net.Conn)) error
Shutdown(ctx context.Context) error
Protocol() protocol.Protocol
}

View File

@@ -9,8 +9,12 @@ import (
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
)
const Proto protocol.Protocol = "quic"
type Listener struct {
// Address is the address to listen on
Address string
@@ -50,6 +54,10 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
}
}
func (l *Listener) Protocol() protocol.Protocol {
return Proto
}
func (l *Listener) Shutdown(ctx context.Context) error {
if l.listener == nil {
return nil

View File

@@ -11,11 +11,14 @@ import (
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/shared/relay"
)
// URLPath is the path for the websocket connection.
const URLPath = relay.WebSocketURLPath
const (
Proto protocol.Protocol = "ws"
URLPath = relay.WebSocketURLPath
)
type Listener struct {
// Address is the address to listen on.
@@ -51,6 +54,10 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
return err
}
func (l *Listener) Protocol() protocol.Protocol {
return Proto
}
func (l *Listener) Shutdown(ctx context.Context) error {
if l.server == nil {
return nil

View File

@@ -6,12 +6,14 @@ import (
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/shared/relay/tls"
log "github.com/sirupsen/logrus"
)
// ListenerConfig is the configuration for the listener.
@@ -26,8 +28,11 @@ type ListenerConfig struct {
// It is the gate between the WebSocket listener and the Relay server logic.
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct {
relay *Relay
listeners []listener.Listener
listenAddr string
relay *Relay
listeners []listener.Listener
listenerMux sync.Mutex
}
// NewServer creates and returns a new relay server instance.
@@ -57,10 +62,14 @@ func NewServer(config Config) (*Server, error) {
// Listen starts the relay server.
func (r *Server) Listen(cfg ListenerConfig) error {
r.listenAddr = cfg.Address
wSListener := &ws.Listener{
Address: cfg.Address,
TLSConfig: cfg.TLSConfig,
}
r.listenerMux.Lock()
r.listeners = append(r.listeners, wSListener)
tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig)
@@ -85,6 +94,8 @@ func (r *Server) Listen(cfg ListenerConfig) error {
}(l)
}
r.listenerMux.Unlock()
wg.Wait()
close(errChan)
var multiErr *multierror.Error
@@ -100,12 +111,15 @@ func (r *Server) Listen(cfg ListenerConfig) error {
func (r *Server) Shutdown(ctx context.Context) error {
r.relay.Shutdown(ctx)
r.listenerMux.Lock()
var multiErr *multierror.Error
for _, l := range r.listeners {
if err := l.Shutdown(ctx); err != nil {
multiErr = multierror.Append(multiErr, err)
}
}
r.listeners = r.listeners[:0]
r.listenerMux.Unlock()
return nberrors.FormatErrorOrNil(multiErr)
}
@@ -113,3 +127,18 @@ func (r *Server) Shutdown(ctx context.Context) error {
func (r *Server) InstanceURL() string {
return r.relay.instanceURL
}
func (r *Server) ListenerProtocols() []protocol.Protocol {
result := make([]protocol.Protocol, 0)
r.listenerMux.Lock()
for _, l := range r.listeners {
result = append(result, l.Protocol())
}
r.listenerMux.Unlock()
return result
}
func (r *Server) ListenAddress() string {
return r.listenAddr
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -111,7 +112,9 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
groupsManager := groups.NewManagerMock()
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)

View File

@@ -162,6 +162,12 @@ components:
description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
type: boolean
example: true
network_traffic_logs_groups:
description: Limits traffic logging to these groups. If unset all peers are enabled.
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m0
network_traffic_packet_counter_enabled:
description: Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance)
type: boolean
@@ -169,6 +175,7 @@ components:
required:
- peer_approval_enabled
- network_traffic_logs_enabled
- network_traffic_logs_groups
- network_traffic_packet_counter_enabled
AccountRequest:
type: object

View File

@@ -260,6 +260,9 @@ type AccountExtraSettings struct {
// NetworkTrafficLogsEnabled Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
NetworkTrafficLogsEnabled bool `json:"network_traffic_logs_enabled"`
// NetworkTrafficLogsGroups Limits traffic logging to these groups. If unset all peers are enabled.
NetworkTrafficLogsGroups []string `json:"network_traffic_logs_groups"`
// NetworkTrafficPacketCounterEnabled Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance)
NetworkTrafficPacketCounterEnabled bool `json:"network_traffic_packet_counter_enabled"`

View File

@@ -1,3 +1,3 @@
package tls
const nbalpn = "nb-quic"
const NBalpn = "nb-quic"

View File

@@ -20,7 +20,7 @@ func ClientQUICTLSConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true, // Debug mode allows insecure connections
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN
NextProtos: []string{NBalpn}, // Ensure this matches the server's ALPN
RootCAs: certPool,
}
}

View File

@@ -19,7 +19,7 @@ func ClientQUICTLSConfig() *tls.Config {
}
return &tls.Config{
NextProtos: []string{nbalpn},
NextProtos: []string{NBalpn},
RootCAs: certPool,
}
}

View File

@@ -23,7 +23,7 @@ func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
cfg.NextProtos = []string{NBalpn}
return cfg, nil
}
@@ -74,6 +74,6 @@ func generateTestTLSConfig() (*tls.Config, error) {
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{nbalpn},
NextProtos: []string{NBalpn},
}, nil
}

View File

@@ -12,6 +12,6 @@ func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
}
cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn}
cfg.NextProtos = []string{NBalpn}
return cfg, nil
}