mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 07:46:38 +00:00
Compare commits
7 Commits
handle-exi
...
fix/remove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68996a1566 | ||
|
|
a03a6eb6f3 | ||
|
|
a4e8647aef | ||
|
|
160b811e21 | ||
|
|
5e607cf4e9 | ||
|
|
0fdb944058 | ||
|
|
ccbabd9e2a |
@@ -11,6 +11,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -97,6 +98,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
settingsMockManager.EXPECT().
|
||||
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||
@@ -108,7 +110,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -176,4 +176,3 @@ nameserver 192.168.0.1
|
||||
t.Errorf("unexpected resolv.conf content: %v", cfg)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -64,9 +64,10 @@ const (
|
||||
)
|
||||
|
||||
type registryConfigurator struct {
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
@@ -177,7 +178,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil {
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
@@ -193,13 +198,24 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
||||
}
|
||||
|
||||
if len(matchDomains) != 0 {
|
||||
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil {
|
||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns match policy: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = count
|
||||
} else {
|
||||
if err := r.removeDNSMatchPolicies(); err != nil {
|
||||
return fmt.Errorf("remove dns match policies: %w", err)
|
||||
}
|
||||
r.nrptEntryCount = 0
|
||||
}
|
||||
|
||||
if err := stateManager.UpdateState(&ShutdownState{
|
||||
Guid: r.guid,
|
||||
GPO: r.gpo,
|
||||
NRPTEntryCount: r.nrptEntryCount,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update shutdown state: %s", err)
|
||||
}
|
||||
|
||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||
@@ -220,28 +236,34 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error {
|
||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||
if r.gpo {
|
||||
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
||||
for i, domain := range domains {
|
||||
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
if r.gpo {
|
||||
policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
}
|
||||
|
||||
singleDomain := []string{domain}
|
||||
|
||||
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
|
||||
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
||||
}
|
||||
|
||||
if r.gpo {
|
||||
if err := refreshGroupPolicy(); err != nil {
|
||||
log.Warnf("failed to refresh group policy: %v", err)
|
||||
}
|
||||
} else {
|
||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
||||
return fmt.Errorf("configure local DNS policy: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains. Domain list: %s", len(domains), domains)
|
||||
return nil
|
||||
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
||||
return len(domains), nil
|
||||
}
|
||||
|
||||
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
|
||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
|
||||
return fmt.Errorf("remove existing dns policy: %w", err)
|
||||
@@ -374,12 +396,25 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
||||
|
||||
func (r *registryConfigurator) removeDNSMatchPolicies() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
// Try to remove the base entries (for backward compatibility)
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err))
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
|
||||
}
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
|
||||
}
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err))
|
||||
for i := 0; i < r.nrptEntryCount; i++ {
|
||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
||||
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
|
||||
}
|
||||
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := refreshGroupPolicy(); err != nil {
|
||||
|
||||
@@ -695,6 +695,12 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
||||
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
|
||||
continue
|
||||
}
|
||||
|
||||
if ns.IP == s.service.RuntimeIP() {
|
||||
log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP)
|
||||
continue
|
||||
}
|
||||
|
||||
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
|
||||
}
|
||||
|
||||
|
||||
@@ -2056,3 +2056,124 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
|
||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||
}
|
||||
|
||||
func TestDNSLoopPrevention(t *testing.T) {
|
||||
wgInterface := &mocWGIface{}
|
||||
service := NewServiceViaMemory(wgInterface)
|
||||
dnsServerIP := service.RuntimeIP()
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: wgInterface,
|
||||
service: service,
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
nsGroups []*nbdns.NameServerGroup
|
||||
expectedHandlers int
|
||||
expectedServers []netip.Addr
|
||||
shouldFilterOwnIP bool
|
||||
}{
|
||||
{
|
||||
name: "FilterOwnDNSServerIP",
|
||||
nsGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{},
|
||||
},
|
||||
},
|
||||
expectedHandlers: 1,
|
||||
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||
shouldFilterOwnIP: true,
|
||||
},
|
||||
{
|
||||
name: "AllServersFiltered",
|
||||
nsGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Primary: false,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
},
|
||||
expectedHandlers: 0,
|
||||
expectedServers: []netip.Addr{},
|
||||
shouldFilterOwnIP: true,
|
||||
},
|
||||
{
|
||||
name: "MixedServersWithOwnIP",
|
||||
nsGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Primary: false,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, // duplicate
|
||||
},
|
||||
Domains: []string{"test.com"},
|
||||
},
|
||||
},
|
||||
expectedHandlers: 1,
|
||||
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||
shouldFilterOwnIP: true,
|
||||
},
|
||||
{
|
||||
name: "NoOwnIPInList",
|
||||
nsGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{},
|
||||
},
|
||||
},
|
||||
expectedHandlers: 1,
|
||||
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
|
||||
shouldFilterOwnIP: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
muxUpdates, err := server.buildUpstreamHandlerUpdate(tt.nsGroups)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, muxUpdates, tt.expectedHandlers)
|
||||
|
||||
if tt.expectedHandlers > 0 {
|
||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
|
||||
|
||||
if tt.shouldFilterOwnIP {
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
assert.NotEqual(t, dnsServerIP, upstream.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
for _, expected := range tt.expectedServers {
|
||||
found := false
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
if upstream.Addr() == expected {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Expected server %s not found", expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,9 @@ import (
|
||||
)
|
||||
|
||||
type ShutdownState struct {
|
||||
Guid string
|
||||
GPO bool
|
||||
Guid string
|
||||
GPO bool
|
||||
NRPTEntryCount int
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -15,8 +16,9 @@ func (s *ShutdownState) Name() string {
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager := ®istryConfigurator{
|
||||
guid: s.Guid,
|
||||
gpo: s.GPO,
|
||||
guid: s.Guid,
|
||||
gpo: s.GPO,
|
||||
nrptEntryCount: s.NRPTEntryCount,
|
||||
}
|
||||
|
||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||
|
||||
@@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
|
||||
defer cancel()
|
||||
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
|
||||
if err != nil {
|
||||
f.handleDNSError(w, query, resp, domain, err)
|
||||
f.handleDNSError(ctx, w, question, resp, domain, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
|
||||
}
|
||||
}
|
||||
|
||||
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
|
||||
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
|
||||
//
|
||||
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
|
||||
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
|
||||
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
|
||||
// only handles A/AAAA queries and returns NOTIMP for other types.
|
||||
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
|
||||
// Try querying for a different record type to see if the domain exists
|
||||
// If the original query was for AAAA, try A. If it was for A, try AAAA.
|
||||
// This helps distinguish between NXDOMAIN and NODATA.
|
||||
var alternativeNetwork string
|
||||
switch originalQtype {
|
||||
case dns.TypeAAAA:
|
||||
alternativeNetwork = "ip4"
|
||||
case dns.TypeA:
|
||||
alternativeNetwork = "ip6"
|
||||
default:
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
|
||||
var dnsErr *net.DNSError
|
||||
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
|
||||
// Alternative query also returned not found - domain truly doesn't exist
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
return
|
||||
}
|
||||
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
return
|
||||
}
|
||||
|
||||
// Alternative query succeeded - domain exists but has no records of this type
|
||||
resp.Rcode = dns.RcodeSuccess
|
||||
}
|
||||
|
||||
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
|
||||
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
|
||||
var dnsErr *net.DNSError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &dnsErr):
|
||||
resp.Rcode = dns.RcodeServerFailure
|
||||
if dnsErr.IsNotFound {
|
||||
// Pass through NXDOMAIN
|
||||
resp.Rcode = dns.RcodeNameError
|
||||
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
|
||||
}
|
||||
|
||||
if dnsErr.Server != "" {
|
||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
|
||||
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
|
||||
} else {
|
||||
log.Warnf(errResolveFailed, domain, err)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package dnsfwd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -16,8 +17,8 @@ import (
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
func Test_getMatchingEntries(t *testing.T) {
|
||||
@@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
assert.Len(t, matches, 3, "Should match 3 patterns")
|
||||
}
|
||||
|
||||
// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes
|
||||
// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type)
|
||||
func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
set := firewall.NewDomainSet([]domain.Domain{d})
|
||||
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}}
|
||||
forwarder.UpdateDomains(entries)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryType uint16
|
||||
setupMocks func()
|
||||
expectedCode int
|
||||
expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case)
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "domain exists but no AAAA records (NODATA)",
|
||||
queryType: dns.TypeAAAA,
|
||||
setupMocks: func() {
|
||||
// First query for AAAA returns not found
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||
// Check query for A records succeeds (domain exists)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
|
||||
},
|
||||
expectedCode: dns.RcodeSuccess,
|
||||
expectNoAnswer: true,
|
||||
description: "Should return NOERROR when domain exists but has no records of requested type",
|
||||
},
|
||||
{
|
||||
name: "domain exists but no A records (NODATA)",
|
||||
queryType: dns.TypeA,
|
||||
setupMocks: func() {
|
||||
// First query for A returns not found
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||
// Check query for AAAA records succeeds (domain exists)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||
Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once()
|
||||
},
|
||||
expectedCode: dns.RcodeSuccess,
|
||||
expectNoAnswer: true,
|
||||
description: "Should return NOERROR when domain exists but has no A records",
|
||||
},
|
||||
{
|
||||
name: "domain doesn't exist (NXDOMAIN)",
|
||||
queryType: dns.TypeA,
|
||||
setupMocks: func() {
|
||||
// First query for A returns not found
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||
// Check query for AAAA also returns not found (domain doesn't exist)
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
|
||||
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
|
||||
},
|
||||
expectedCode: dns.RcodeNameError,
|
||||
expectNoAnswer: true,
|
||||
description: "Should return NXDOMAIN when domain doesn't exist at all",
|
||||
},
|
||||
{
|
||||
name: "domain exists with records (normal success)",
|
||||
queryType: dns.TypeA,
|
||||
setupMocks: func() {
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
|
||||
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
|
||||
// Expect firewall update for successful resolution
|
||||
expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32)
|
||||
mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once()
|
||||
},
|
||||
expectedCode: dns.RcodeSuccess,
|
||||
expectNoAnswer: false,
|
||||
description: "Should return NOERROR with answer when records exist",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset mock expectations
|
||||
mockResolver.ExpectedCalls = nil
|
||||
mockResolver.Calls = nil
|
||||
mockFirewall.ExpectedCalls = nil
|
||||
mockFirewall.Calls = nil
|
||||
|
||||
tt.setupMocks()
|
||||
|
||||
query := &dns.Msg{}
|
||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||
|
||||
var writtenResp *dns.Msg
|
||||
mockWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
writtenResp = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resp := forwarder.handleDNSQuery(mockWriter, query)
|
||||
|
||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
||||
if resp != nil && writtenResp == nil {
|
||||
writtenResp = resp
|
||||
}
|
||||
|
||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
||||
|
||||
if tt.expectNoAnswer {
|
||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
||||
}
|
||||
|
||||
mockResolver.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
// Test handling of malformed query with no questions
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/bind"
|
||||
@@ -1564,13 +1565,14 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
AnyTimes()
|
||||
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -302,13 +303,14 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"fyne.io/systray"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -232,19 +231,3 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string)
|
||||
|
||||
log.Printf("command '%s %s' completed successfully", command, arg)
|
||||
}
|
||||
|
||||
func (h *eventHandler) logout(ctx context.Context) error {
|
||||
client, err := h.client.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service client: %w", err)
|
||||
}
|
||||
|
||||
_, err = client.Logout(ctx, &proto.LogoutRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout failed: %w", err)
|
||||
}
|
||||
|
||||
h.client.getSrvConfig()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -40,13 +40,12 @@ func (s *serviceClient) showProfilesUI() {
|
||||
list := widget.NewList(
|
||||
func() int { return len(profiles) },
|
||||
func() fyne.CanvasObject {
|
||||
// Each item: Selected indicator, Name, spacer, Select, Logout & Remove buttons
|
||||
// Each item: Selected indicator, Name, spacer, Select & Remove buttons
|
||||
return container.NewHBox(
|
||||
widget.NewLabel(""), // indicator
|
||||
widget.NewLabel(""), // profile name
|
||||
layout.NewSpacer(),
|
||||
widget.NewButton("Select", nil),
|
||||
widget.NewButton("Deregister", nil),
|
||||
widget.NewButton("Remove", nil),
|
||||
)
|
||||
},
|
||||
@@ -56,8 +55,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
indicator := row.Objects[0].(*widget.Label)
|
||||
nameLabel := row.Objects[1].(*widget.Label)
|
||||
selectBtn := row.Objects[3].(*widget.Button)
|
||||
logoutBtn := row.Objects[4].(*widget.Button)
|
||||
removeBtn := row.Objects[5].(*widget.Button)
|
||||
removeBtn := row.Objects[4].(*widget.Button)
|
||||
|
||||
profile := profiles[i]
|
||||
// Show a checkmark if selected
|
||||
@@ -127,12 +125,6 @@ func (s *serviceClient) showProfilesUI() {
|
||||
)
|
||||
}
|
||||
|
||||
logoutBtn.Show()
|
||||
logoutBtn.SetText("Deregister")
|
||||
logoutBtn.OnTapped = func() {
|
||||
s.handleProfileLogout(profile.Name, refresh)
|
||||
}
|
||||
|
||||
// Remove profile
|
||||
removeBtn.SetText("Remove")
|
||||
removeBtn.OnTapped = func() {
|
||||
@@ -332,52 +324,6 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
|
||||
dialog.ShowConfirm(
|
||||
"Deregister",
|
||||
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
|
||||
func(confirm bool) {
|
||||
if !confirm {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get service client: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("failed to connect to service"), s.wProfiles)
|
||||
return
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get current user: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("failed to get current user"), s.wProfiles)
|
||||
return
|
||||
}
|
||||
|
||||
username := currUser.Username
|
||||
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
|
||||
ProfileName: &profileName,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("logout failed: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("deregister failed"), s.wProfiles)
|
||||
return
|
||||
}
|
||||
|
||||
dialog.ShowInformation(
|
||||
"Deregistered",
|
||||
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
|
||||
s.wProfiles,
|
||||
)
|
||||
|
||||
refreshCallback()
|
||||
},
|
||||
s.wProfiles,
|
||||
)
|
||||
}
|
||||
|
||||
type subItem struct {
|
||||
*systray.MenuItem
|
||||
ctx context.Context
|
||||
@@ -393,7 +339,6 @@ type profileMenu struct {
|
||||
emailMenuItem *systray.MenuItem
|
||||
profileSubItems []*subItem
|
||||
manageProfilesSubItem *subItem
|
||||
logoutSubItem *subItem
|
||||
profilesState []Profile
|
||||
downClickCallback func() error
|
||||
upClickCallback func() error
|
||||
@@ -600,30 +545,6 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
}()
|
||||
|
||||
// Add Logout menu item
|
||||
ctx2, cancel2 := context.WithCancel(context.Background())
|
||||
logoutItem := p.profileMenuItem.AddSubMenuItem("Deregister", "")
|
||||
p.logoutSubItem = &subItem{logoutItem, ctx2, cancel2}
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx2.Done():
|
||||
return
|
||||
case _, ok := <-logoutItem.ClickedCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := p.eventHandler.logout(p.ctx); err != nil {
|
||||
log.Errorf("logout failed: %v", err)
|
||||
p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
|
||||
} else {
|
||||
p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
|
||||
p.profileMenuItem.SetTitle(activeProf.ProfileName)
|
||||
} else {
|
||||
@@ -646,12 +567,6 @@ func (p *profileMenu) clear(profiles []Profile) {
|
||||
p.manageProfilesSubItem.cancel()
|
||||
p.manageProfilesSubItem = nil
|
||||
}
|
||||
|
||||
if p.logoutSubItem != nil {
|
||||
p.logoutSubItem.Remove()
|
||||
p.logoutSubItem.cancel()
|
||||
p.logoutSubItem = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *profileMenu) updateMenu() {
|
||||
|
||||
2
go.mod
2
go.mod
@@ -63,7 +63,7 @@ require (
|
||||
github.com/miekg/dns v1.1.59
|
||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||
github.com/nadoo/ipset v0.5.0
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
|
||||
github.com/okta/okta-sdk-golang/v2 v2.18.0
|
||||
github.com/oschwald/maxminddb-golang v1.12.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
|
||||
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
|
||||
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f h1:YmqNWdRbeVn1lSpkLzIiFHX2cndRuaVYyynx2ibrOtg=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e h1:S85laGfx1UP+nmRF9smP6/TY965kLWz41PbBK1TX8g0=
|
||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e/go.mod h1:Jjve0+eUjOLKL3PJtAhjfM2iJ0SxWio5elHqlV1ymP8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
|
||||
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
|
||||
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
|
||||
|
||||
@@ -34,6 +34,7 @@ import (
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
@@ -45,7 +46,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/metrics"
|
||||
@@ -220,7 +220,8 @@ var (
|
||||
return fmt.Errorf("build default manager: %v", err)
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager)
|
||||
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager, groupsManager)
|
||||
|
||||
trustedPeers := config.ReverseProxy.TrustedPeers
|
||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||
@@ -277,7 +278,6 @@ var (
|
||||
config.GetAuthAudiences(),
|
||||
config.HttpConfig.IdpSignKeyRefreshEnabled)
|
||||
|
||||
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
|
||||
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
|
||||
routersManager := routers.NewManager(store, permissionsManager, accountManager)
|
||||
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
|
||||
|
||||
@@ -21,6 +21,7 @@ type Manager interface {
|
||||
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
|
||||
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error)
|
||||
RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error)
|
||||
GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
@@ -142,6 +143,10 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa
|
||||
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
|
||||
return m.store.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
|
||||
}
|
||||
|
||||
func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum {
|
||||
groupsInfoMap := make(map[string][]api.GroupMinimum, idCount)
|
||||
groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation)
|
||||
@@ -202,6 +207,10 @@ func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockManager) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
return &mockManager{}
|
||||
}
|
||||
|
||||
@@ -662,7 +662,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
||||
}
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings) *proto.SyncResponse {
|
||||
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
@@ -674,7 +674,7 @@ func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
@@ -750,7 +750,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra)
|
||||
peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
|
||||
}
|
||||
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -199,6 +199,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
settings.Extra = &types.ExtraSettings{
|
||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||
}
|
||||
}
|
||||
@@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
apiSettings.Extra = &api.AccountExtraSettings{
|
||||
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
||||
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
|
||||
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
|
||||
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/formatter/hook"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -446,6 +447,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||
@@ -455,7 +457,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
|
||||
return nil, nil, "", cleanup, err
|
||||
}
|
||||
|
||||
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
|
||||
ephemeralMgr := NewEphemeralManager(store, accountManager)
|
||||
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -216,7 +217,8 @@ func startServer(
|
||||
t.Fatalf("failed creating an account manager: %v", err)
|
||||
}
|
||||
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
groupsManager := groups.NewManager(str, permissionsManager, accountManager)
|
||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := server.NewServer(
|
||||
context.Background(),
|
||||
config,
|
||||
|
||||
@@ -1275,8 +1275,9 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
}
|
||||
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
|
||||
|
||||
peerGroups := account.GetPeerGroups(p.ID)
|
||||
start = time.Now()
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting)
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups))
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
@@ -1386,7 +1387,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings)
|
||||
peerGroups := account.GetPeerGroups(peerId)
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups))
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
}
|
||||
|
||||
|
||||
@@ -1164,7 +1164,7 @@ func TestToSyncResponse(t *testing.T) {
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
|
||||
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil)
|
||||
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{})
|
||||
|
||||
assert.NotNil(t, response)
|
||||
// assert peer config
|
||||
|
||||
@@ -68,6 +68,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
|
||||
// Once we migrate the peer approval to settings manager this merging is obsolete
|
||||
if settings.Extra != nil {
|
||||
settings.Extra.FlowEnabled = extraSettings.FlowEnabled
|
||||
settings.Extra.FlowGroups = extraSettings.FlowGroups
|
||||
settings.Extra.FlowPacketCounterEnabled = extraSettings.FlowPacketCounterEnabled
|
||||
settings.Extra.FlowENCollectionEnabled = extraSettings.FlowENCollectionEnabled
|
||||
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
|
||||
@@ -93,6 +94,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
|
||||
}
|
||||
|
||||
settings.Extra.FlowEnabled = extraSettings.FlowEnabled
|
||||
settings.Extra.FlowGroups = extraSettings.FlowGroups
|
||||
|
||||
return settings.Extra, nil
|
||||
}
|
||||
|
||||
@@ -11,13 +11,13 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
)
|
||||
|
||||
const defaultDuration = 12 * time.Hour
|
||||
@@ -39,13 +39,14 @@ type TimeBasedAuthSecretsManager struct {
|
||||
relayHmacToken *authv2.Generator
|
||||
updateManager *PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
groupsManager groups.Manager
|
||||
turnCancelMap map[string]chan struct{}
|
||||
relayCancelMap map[string]chan struct{}
|
||||
}
|
||||
|
||||
type Token auth.Token
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager {
|
||||
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
||||
mgr := &TimeBasedAuthSecretsManager{
|
||||
updateManager: updateManager,
|
||||
turnCfg: turnCfg,
|
||||
@@ -53,6 +54,7 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
|
||||
turnCancelMap: make(map[string]chan struct{}),
|
||||
relayCancelMap: make(map[string]chan struct{}),
|
||||
settingsManager: settingsManager,
|
||||
groupsManager: groupsManager,
|
||||
}
|
||||
|
||||
if turnCfg != nil {
|
||||
@@ -258,6 +260,11 @@ func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, p
|
||||
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
|
||||
}
|
||||
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings)
|
||||
peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
|
||||
}
|
||||
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
|
||||
update.NetbirdConfig = extendedConfig
|
||||
}
|
||||
|
||||
@@ -13,9 +13,10 @@ import (
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -40,13 +41,14 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*types.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager)
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
|
||||
turnCredentials, err := tested.GenerateTurnToken()
|
||||
require.NoError(t, err)
|
||||
@@ -91,13 +93,14 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*types.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager)
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@@ -193,13 +196,14 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*types.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager)
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
|
||||
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
||||
|
||||
@@ -2,6 +2,7 @@ package types
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -87,21 +88,21 @@ type ExtraSettings struct {
|
||||
// IntegratedValidatorGroups list of group IDs to be used with integrated approval configurations
|
||||
IntegratedValidatorGroups []string `gorm:"serializer:json"`
|
||||
|
||||
FlowEnabled bool `gorm:"-"`
|
||||
FlowPacketCounterEnabled bool `gorm:"-"`
|
||||
FlowENCollectionEnabled bool `gorm:"-"`
|
||||
FlowDnsCollectionEnabled bool `gorm:"-"`
|
||||
FlowEnabled bool `gorm:"-"`
|
||||
FlowGroups []string `gorm:"-"`
|
||||
FlowPacketCounterEnabled bool `gorm:"-"`
|
||||
FlowENCollectionEnabled bool `gorm:"-"`
|
||||
FlowDnsCollectionEnabled bool `gorm:"-"`
|
||||
}
|
||||
|
||||
// Copy copies the ExtraSettings struct
|
||||
func (e *ExtraSettings) Copy() *ExtraSettings {
|
||||
var cpGroup []string
|
||||
|
||||
return &ExtraSettings{
|
||||
PeerApprovalEnabled: e.PeerApprovalEnabled,
|
||||
IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...),
|
||||
IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups),
|
||||
IntegratedValidator: e.IntegratedValidator,
|
||||
FlowEnabled: e.FlowEnabled,
|
||||
FlowGroups: slices.Clone(e.FlowGroups),
|
||||
FlowPacketCounterEnabled: e.FlowPacketCounterEnabled,
|
||||
FlowENCollectionEnabled: e.FlowENCollectionEnabled,
|
||||
FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled,
|
||||
|
||||
33
relay/cmd/pprof.go
Normal file
33
relay/cmd/pprof.go
Normal file
@@ -0,0 +1,33 @@
|
||||
//go:build pprof
|
||||
// +build pprof
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
addr := pprofAddr()
|
||||
go pprof(addr)
|
||||
}
|
||||
|
||||
func pprofAddr() string {
|
||||
listenAddr := os.Getenv("NB_PPROF_ADDR")
|
||||
if listenAddr == "" {
|
||||
return "localhost:6969"
|
||||
}
|
||||
|
||||
return listenAddr
|
||||
}
|
||||
|
||||
func pprof(listenAddr string) {
|
||||
log.Infof("listening pprof on: %s\n", listenAddr)
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
log.Fatalf("Failed to start pprof: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
@@ -111,7 +112,9 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -162,6 +162,12 @@ components:
|
||||
description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
|
||||
type: boolean
|
||||
example: true
|
||||
network_traffic_logs_groups:
|
||||
description: Limits traffic logging to these groups. If unset all peers are enabled.
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7m0
|
||||
network_traffic_packet_counter_enabled:
|
||||
description: Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance)
|
||||
type: boolean
|
||||
@@ -169,6 +175,7 @@ components:
|
||||
required:
|
||||
- peer_approval_enabled
|
||||
- network_traffic_logs_enabled
|
||||
- network_traffic_logs_groups
|
||||
- network_traffic_packet_counter_enabled
|
||||
AccountRequest:
|
||||
type: object
|
||||
|
||||
@@ -260,6 +260,9 @@ type AccountExtraSettings struct {
|
||||
// NetworkTrafficLogsEnabled Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
|
||||
NetworkTrafficLogsEnabled bool `json:"network_traffic_logs_enabled"`
|
||||
|
||||
// NetworkTrafficLogsGroups Limits traffic logging to these groups. If unset all peers are enabled.
|
||||
NetworkTrafficLogsGroups []string `json:"network_traffic_logs_groups"`
|
||||
|
||||
// NetworkTrafficPacketCounterEnabled Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance)
|
||||
NetworkTrafficPacketCounterEnabled bool `json:"network_traffic_packet_counter_enabled"`
|
||||
|
||||
|
||||
Reference in New Issue
Block a user