Compare commits

...

31 Commits

Author SHA1 Message Date
Zoltán Papp
29d6630686 Add MySQL initialization script and update Docker configuration
Original error was:
Access denied; you need (at least one of) the SUPER or SYSTEM_VARIABLES_ADMIN privilege(s) for this operation
2025-08-14 16:21:38 +02:00
Zoltán Papp
7247782dd7 Close ICE agent on failure to gather candidates or establish connection 2025-08-14 15:26:20 +02:00
Zoltan Papp
2122b65e84 Change log level from Debug to Warn for candidate gathering failure 2025-08-14 12:24:22 +02:00
Zoltán Papp
45c8f8b4ae Merge branch 'main' into fix/ice-handshake 2025-08-14 09:58:25 +02:00
Misha Bragin
e97f853909 Improve wording in the NetBird client app (#4316) 2025-08-13 22:03:48 +02:00
hakansa
70db8751d7 [client] Add --disable-update-settings flag to the service (#4335)
[client] Add --disable-update-settings flag to the service (#4335)
2025-08-13 21:05:12 +03:00
Zoltan Papp
86a00ab4af Fix Go tarball version in FreeBSD build configuration (#4339) 2025-08-13 13:52:11 +02:00
Zoltán Papp
fdc8cc3500 Merge branch 'main' into fix/ice-handshake 2025-08-13 10:47:54 +02:00
Zoltan Papp
3d4b502126 [server] Add health check HTTP endpoint for Relay server (#4297)
The health check endpoint listens on a dedicated HTTP server.
By default, it is available at 0.0.0.0:9000/health. This can be configured using the --health-listen-address flag.

The results are cached for 3 seconds to avoid excessive calls.

The health check performs the following:

Checks the number of active listeners.
Validates each listener via WebSocket and QUIC dials, including TLS certificate verification.
2025-08-13 10:40:04 +02:00
Zoltan Papp
48f9445af9 Optimize ICE guard reconnection logic and remove unnecessary wait function (#4323) 2025-08-13 10:38:59 +02:00
Bethuel Mmbaga
a4e8647aef [management] Enable flow groups (#4230)
Adds the ability to limit traffic events logging to specific peer groups
2025-08-13 00:00:40 +03:00
Viktor Liu
160b811e21 [client] Distinguish between NXDOMAIN and NODATA in the dns forwarder (#4321) 2025-08-12 15:59:42 +02:00
Viktor Liu
5e607cf4e9 [client] Skip dns upstream servers pointing to our dns server IP to prevent loops (#4330) 2025-08-12 15:41:23 +02:00
Viktor Liu
0fdb944058 [client] Create NRPT rules separately per domain (#4329) 2025-08-12 15:40:37 +02:00
Zoltan Papp
ccbabd9e2a Add pprof support for Relay server (#4203) 2025-08-12 12:24:24 +02:00
Zoltán Papp
bb0897dd85 Change session id from int to string 2025-08-08 01:00:18 +02:00
Zoltán Papp
5fa85aaaf7 Add log lines 2025-08-07 23:55:59 +02:00
Zoltán Papp
0fcb272f05 Fix worker restart
When the signal connection is lost, we do not recreate the gRPC instance but instead cancel the context.
As a result, the worker is stopped but not restarted.
2025-08-07 23:29:05 +02:00
Zoltán Papp
1c1af3f5be Fix proto import 2025-08-06 10:29:25 +02:00
Zoltán Papp
b283c7877c Merge branch 'main' into fix/ice-handshake 2025-08-06 09:55:41 +02:00
Zoltán Papp
90ec065c65 Fix worker cancellation 2025-08-05 11:09:11 +02:00
Zoltán Papp
8d9411e11e Update decryption worker to accept context in AddMsg method for avoid deadlock 2025-08-05 11:06:18 +02:00
Zoltán Papp
c207e0b0b3 Fix logging 2025-08-05 10:12:55 +02:00
Zoltán Papp
5194ba1580 Add SessionIDString method to OfferAnswer to avoid nil pointer exceptions 2025-08-05 10:08:18 +02:00
Zoltan Papp
a57cb82f82 Fix pointer logging 2025-08-05 09:54:06 +02:00
Zoltan Papp
d0ed9bb59d Fix pointer logging 2025-08-05 09:47:01 +02:00
Zoltan Papp
fa04b2ca77 Remove unnecessary return statement in ICE worker function 2025-08-05 00:34:09 +02:00
Zoltan Papp
d8ab5ceec6 Fix test 2025-08-05 00:15:14 +02:00
Zoltán Papp
daa27a3e24 Fix grammar 2025-08-04 20:35:55 +02:00
Zoltán Papp
b0c95d9bd2 Enhance ICE handshake with session ID management and improve message handling 2025-08-04 20:17:11 +02:00
Zoltán Papp
376fc954ee Refactor ICE connection handling for improved clarity and performance 2025-08-04 12:12:18 +02:00
78 changed files with 1778 additions and 534 deletions

View File

@@ -25,8 +25,7 @@ jobs:
release: "14.2" release: "14.2"
prepare: | prepare: |
pkg install -y curl pkgconf xorg pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1) GO_TARBALL="go1.23.12.freebsd-amd64.tar.gz"
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL" GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL" curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL" tar -C /usr/local -vxzf "$GO_TARBALL"

View File

@@ -47,6 +47,8 @@ jobs:
--health-timeout 5s --health-timeout 5s
ports: ports:
- 3306:3306 - 3306:3306
volumes:
- ./mysql-init.sql:/docker-entrypoint-initdb.d/init.sql
steps: steps:
- name: Set Database Connection String - name: Set Database Connection String
run: | run: |

View File

@@ -33,7 +33,7 @@ var (
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the NetBird daemon.", Long: "Commands for debugging and logging within the NetBird daemon.",
} }
var debugBundleCmd = &cobra.Command{ var debugBundleCmd = &cobra.Command{

View File

@@ -14,7 +14,8 @@ import (
var downCmd = &cobra.Command{ var downCmd = &cobra.Command{
Use: "down", Use: "down",
Short: "down netbird connections", Short: "Disconnect from the NetBird network",
Long: "Disconnect the NetBird client from the network and management service. This will terminate all active connections with the remote peers.",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)

View File

@@ -31,7 +31,8 @@ func init() {
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the NetBird Management Service (first run)", Short: "Log in to the NetBird network",
Long: "Log in to the NetBird network using a setup key or SSO",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setEnvAndFlags(cmd); err != nil { if err := setEnvAndFlags(cmd); err != nil {
return fmt.Errorf("set env and flags: %v", err) return fmt.Errorf("set env and flags: %v", err)

View File

@@ -14,7 +14,8 @@ import (
var logoutCmd = &cobra.Command{ var logoutCmd = &cobra.Command{
Use: "deregister", Use: "deregister",
Aliases: []string{"logout"}, Aliases: []string{"logout"},
Short: "deregister from the NetBird Management Service and delete peer", Short: "Deregister from the NetBird management service and delete this peer",
Long: "This command will deregister the current peer from the NetBird management service and all associated configuration. Use with caution as this will remove the peer from the network.",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)

View File

@@ -15,7 +15,7 @@ var appendFlag bool
var networksCMD = &cobra.Command{ var networksCMD = &cobra.Command{
Use: "networks", Use: "networks",
Aliases: []string{"routes"}, Aliases: []string{"routes"},
Short: "Manage networks", Short: "Manage connections to NetBird Networks and Resources",
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`, Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
} }

View File

@@ -16,13 +16,13 @@ import (
var profileCmd = &cobra.Command{ var profileCmd = &cobra.Command{
Use: "profile", Use: "profile",
Short: "manage NetBird profiles", Short: "Manage NetBird client profiles",
Long: `Manage NetBird profiles, allowing you to list, switch, and remove profiles.`, Long: `Commands to list, add, remove, and switch profiles. Profiles allow you to maintain different accounts in one client app.`,
} }
var profileListCmd = &cobra.Command{ var profileListCmd = &cobra.Command{
Use: "list", Use: "list",
Short: "list all profiles", Short: "List all profiles",
Long: `List all available profiles in the NetBird client.`, Long: `List all available profiles in the NetBird client.`,
Aliases: []string{"ls"}, Aliases: []string{"ls"},
RunE: listProfilesFunc, RunE: listProfilesFunc,
@@ -30,7 +30,7 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{ var profileAddCmd = &cobra.Command{
Use: "add <profile_name>", Use: "add <profile_name>",
Short: "add a new profile", Short: "Add a new profile",
Long: `Add a new profile to the NetBird client. The profile name must be unique.`, Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: addProfileFunc, RunE: addProfileFunc,
@@ -38,16 +38,16 @@ var profileAddCmd = &cobra.Command{
var profileRemoveCmd = &cobra.Command{ var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>", Use: "remove <profile_name>",
Short: "remove a profile", Short: "Remove a profile",
Long: `Remove a profile from the NetBird client. The profile must not be active.`, Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: removeProfileFunc, RunE: removeProfileFunc,
} }
var profileSelectCmd = &cobra.Command{ var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>", Use: "select <profile_name>",
Short: "select a profile", Short: "Select a profile",
Long: `Select a profile to be the active profile in the NetBird client. The profile must exist.`, Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
RunE: selectProfileFunc, RunE: selectProfileFunc,
} }

View File

@@ -73,6 +73,7 @@ var (
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
lazyConnEnabled bool lazyConnEnabled bool
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",

View File

@@ -19,7 +19,7 @@ import (
var serviceCmd = &cobra.Command{ var serviceCmd = &cobra.Command{
Use: "service", Use: "service",
Short: "manages NetBird service", Short: "Manage the NetBird daemon service",
} }
var ( var (
@@ -42,7 +42,8 @@ func init() {
} }
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd) serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile.") serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` + serviceEnvDesc := `Sets extra environment variables for the service. ` +

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
} }
} }
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled) serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
if err := serverInstance.Start(); err != nil { if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err) log.Fatalf("failed to start daemon: %v", err)
} }

View File

@@ -49,6 +49,14 @@ func buildServiceArguments() []string {
args = append(args, "--log-file", logFile) args = append(args, "--log-file", logFile)
} }
if profilesDisabled {
args = append(args, "--disable-profiles")
}
if updateSettingsDisabled {
args = append(args, "--disable-update-settings")
}
return args return args
} }
@@ -99,7 +107,7 @@ func createServiceConfigForInstall() (*service.Config, error) {
var installCmd = &cobra.Command{ var installCmd = &cobra.Command{
Use: "install", Use: "install",
Short: "installs NetBird service", Short: "Install NetBird service",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
if err := setupServiceCommand(cmd); err != nil { if err := setupServiceCommand(cmd); err != nil {
return err return err

View File

@@ -40,7 +40,7 @@ var sshCmd = &cobra.Command{
return nil return nil
}, },
Short: "connect to a remote SSH server", Short: "Connect to a remote SSH server",
RunE: func(cmd *cobra.Command, args []string) error { RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd) SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd) SetFlagsFromEnvVars(cmd)

View File

@@ -32,7 +32,8 @@ var (
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
Use: "status", Use: "status",
Short: "status of the Netbird Service", Short: "Display NetBird client status",
Long: "Display the current status of the NetBird client, including connection status, peer information, and network details.",
RunE: statusFunc, RunE: statusFunc,
} }

View File

@@ -11,6 +11,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -26,8 +27,8 @@ import (
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
mgmt "github.com/netbirdio/netbird/management/server" mgmt "github.com/netbirdio/netbird/management/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto" sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server" sig "github.com/netbirdio/netbird/signal/server"
) )
@@ -97,6 +98,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT(). settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
@@ -108,7 +110,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -134,7 +136,7 @@ func startClientDaemon(
s := grpc.NewServer() s := grpc.NewServer()
server := client.New(ctx, server := client.New(ctx,
"", "", false) "", "", false, false)
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -53,7 +53,8 @@ var (
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
Short: "install, login and start NetBird client", Short: "Connect to the NetBird network",
Long: "Connect to the NetBird network using the provided setup key or SSO auth. This command will bring up the WireGuard interface, connect to the management server, and establish peer-to-peer connections with other peers in the network if required.",
RunE: upFunc, RunE: upFunc,
} }
) )

View File

@@ -9,7 +9,7 @@ import (
var ( var (
versionCmd = &cobra.Command{ versionCmd = &cobra.Command{
Use: "version", Use: "version",
Short: "prints NetBird version", Short: "Print the NetBird's client application version",
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
cmd.SetOut(cmd.OutOrStdout()) cmd.SetOut(cmd.OutOrStdout())
cmd.Println(version.NetbirdVersion()) cmd.Println(version.NetbirdVersion())

View File

@@ -176,4 +176,3 @@ nameserver 192.168.0.1
t.Errorf("unexpected resolv.conf content: %v", cfg) t.Errorf("unexpected resolv.conf content: %v", cfg)
} }
} }

View File

@@ -67,6 +67,7 @@ type registryConfigurator struct {
guid string guid string
routingAll bool routingAll bool
gpo bool gpo bool
nrptEntryCount int
} }
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
@@ -177,7 +178,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid, GPO: r.gpo}); err != nil { if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err) log.Errorf("failed to update shutdown state: %s", err)
} }
@@ -193,13 +198,24 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
if err := r.addDNSMatchPolicy(matchDomains, config.ServerIP); err != nil { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil {
return fmt.Errorf("add dns match policy: %w", err) return fmt.Errorf("add dns match policy: %w", err)
} }
r.nrptEntryCount = count
} else { } else {
if err := r.removeDNSMatchPolicies(); err != nil { if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err) return fmt.Errorf("remove dns match policies: %w", err)
} }
r.nrptEntryCount = 0
}
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
if err := r.updateSearchDomains(searchDomains); err != nil { if err := r.updateSearchDomains(searchDomains); err != nil {
@@ -220,28 +236,34 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
return nil return nil
} }
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) error { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
policyPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
if r.gpo { if r.gpo {
if err := r.configureDNSPolicy(gpoDnsPolicyConfigMatchPath, domains, ip); err != nil { policyPath = fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
return fmt.Errorf("configure GPO DNS policy: %w", err)
} }
singleDomain := []string{domain}
if err := r.configureDNSPolicy(policyPath, singleDomain, ip); err != nil {
return i, fmt.Errorf("configure DNS policy for domain %s: %w", domain, err)
}
log.Debugf("added NRPT entry for domain: %s", domain)
}
if r.gpo {
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {
log.Warnf("failed to refresh group policy: %v", err) log.Warnf("failed to refresh group policy: %v", err)
} }
} else {
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
return fmt.Errorf("configure local DNS policy: %w", err)
}
} }
log.Infof("added %d match domains. Domain list: %s", len(domains), domains) log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
return nil return len(domains), nil
} }
// configureDNSPolicy handles the actual configuration of a DNS policy at the specified path
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(policyPath); err != nil {
return fmt.Errorf("remove existing dns policy: %w", err) return fmt.Errorf("remove existing dns policy: %w", err)
@@ -374,12 +396,25 @@ func (r *registryConfigurator) restoreHostDNS() error {
func (r *registryConfigurator) removeDNSMatchPolicies() error { func (r *registryConfigurator) removeDNSMatchPolicies() error {
var merr *multierror.Error var merr *multierror.Error
// Try to remove the base entries (for backward compatibility)
if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil { if err := removeRegistryKeyFromDNSPolicyConfig(dnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local registry key: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove local base entry: %w", err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO base entry: %w", err))
} }
if err := removeRegistryKeyFromDNSPolicyConfig(gpoDnsPolicyConfigMatchPath); err != nil { for i := 0; i < r.nrptEntryCount; i++ {
merr = multierror.Append(merr, fmt.Errorf("remove GPO registry key: %w", err)) localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
if err := removeRegistryKeyFromDNSPolicyConfig(localPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove local entry %d: %w", i, err))
}
if err := removeRegistryKeyFromDNSPolicyConfig(gpoPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove GPO entry %d: %w", i, err))
}
} }
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {

View File

@@ -695,6 +695,12 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String()) ns.IP.String(), ns.NSType.String(), nbdns.UDPNameServerType.String())
continue continue
} }
if ns.IP == s.service.RuntimeIP() {
log.Warnf("skipping nameserver %s as it matches our DNS server IP, preventing potential loop", ns.IP)
continue
}
handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort()) handler.upstreamServers = append(handler.upstreamServers, ns.AddrPort())
} }

View File

@@ -2056,3 +2056,124 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal") assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain) assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
} }
func TestDNSLoopPrevention(t *testing.T) {
wgInterface := &mocWGIface{}
service := NewServiceViaMemory(wgInterface)
dnsServerIP := service.RuntimeIP()
server := &DefaultServer{
ctx: context.Background(),
wgInterface: wgInterface,
service: service,
localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
}
tests := []struct {
name string
nsGroups []*nbdns.NameServerGroup
expectedHandlers int
expectedServers []netip.Addr
shouldFilterOwnIP bool
}{
{
name: "FilterOwnDNSServerIP",
nsGroups: []*nbdns.NameServerGroup{
{
Primary: true,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: true,
},
{
name: "AllServersFiltered",
nsGroups: []*nbdns.NameServerGroup{
{
Primary: false,
NameServers: []nbdns.NameServer{
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{"example.com"},
},
},
expectedHandlers: 0,
expectedServers: []netip.Addr{},
shouldFilterOwnIP: true,
},
{
name: "MixedServersWithOwnIP",
nsGroups: []*nbdns.NameServerGroup{
{
Primary: false,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: dnsServerIP, NSType: nbdns.UDPNameServerType, Port: 53}, // duplicate
},
Domains: []string{"test.com"},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: true,
},
{
name: "NoOwnIPInList",
nsGroups: []*nbdns.NameServerGroup{
{
Primary: true,
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53},
{IP: netip.MustParseAddr("1.1.1.1"), NSType: nbdns.UDPNameServerType, Port: 53},
},
Domains: []string{},
},
},
expectedHandlers: 1,
expectedServers: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")},
shouldFilterOwnIP: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
muxUpdates, err := server.buildUpstreamHandlerUpdate(tt.nsGroups)
assert.NoError(t, err)
assert.Len(t, muxUpdates, tt.expectedHandlers)
if tt.expectedHandlers > 0 {
handler := muxUpdates[0].handler.(*upstreamResolver)
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
if tt.shouldFilterOwnIP {
for _, upstream := range handler.upstreamServers {
assert.NotEqual(t, dnsServerIP, upstream.Addr())
}
}
for _, expected := range tt.expectedServers {
found := false
for _, upstream := range handler.upstreamServers {
if upstream.Addr() == expected {
found = true
break
}
}
assert.True(t, found, "Expected server %s not found", expected)
}
}
})
}
}

View File

@@ -7,6 +7,7 @@ import (
type ShutdownState struct { type ShutdownState struct {
Guid string Guid string
GPO bool GPO bool
NRPTEntryCount int
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -17,6 +18,7 @@ func (s *ShutdownState) Cleanup() error {
manager := &registryConfigurator{ manager := &registryConfigurator{
guid: s.Guid, guid: s.Guid,
gpo: s.GPO, gpo: s.GPO,
nrptEntryCount: s.NRPTEntryCount,
} }
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {

View File

@@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
defer cancel() defer cancel()
ips, err := f.resolver.LookupNetIP(ctx, network, domain) ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, query, resp, domain, err) f.handleDNSError(ctx, w, question, resp, domain, err)
return nil return nil
} }
@@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
} }
} }
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
//
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
// only handles A/AAAA queries and returns NOTIMP for other types.
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
// Try querying for a different record type to see if the domain exists
// If the original query was for AAAA, try A. If it was for A, try AAAA.
// This helps distinguish between NXDOMAIN and NODATA.
var alternativeNetwork string
switch originalQtype {
case dns.TypeAAAA:
alternativeNetwork = "ip4"
case dns.TypeA:
alternativeNetwork = "ip6"
default:
resp.Rcode = dns.RcodeNameError
return
}
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
// Alternative query also returned not found - domain truly doesn't exist
resp.Rcode = dns.RcodeNameError
return
}
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
resp.Rcode = dns.RcodeSuccess
return
}
// Alternative query succeeded - domain exists but has no records of this type
resp.Rcode = dns.RcodeSuccess
}
// handleDNSError processes DNS lookup errors and sends an appropriate error response // handleDNSError processes DNS lookup errors and sends an appropriate error response
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) { func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
var dnsErr *net.DNSError var dnsErr *net.DNSError
switch { switch {
case errors.As(err, &dnsErr): case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound { if dnsErr.IsNotFound {
// Pass through NXDOMAIN f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
resp.Rcode = dns.RcodeNameError
} }
if dnsErr.Server != "" { if dnsErr.Server != "" {
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err) log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
} else { } else {
log.Warnf(errResolveFailed, domain, err) log.Warnf(errResolveFailed, domain, err)
} }

View File

@@ -3,6 +3,7 @@ package dnsfwd
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"strings" "strings"
"testing" "testing"
@@ -16,8 +17,8 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
func Test_getMatchingEntries(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
@@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
assert.Len(t, matches, 3, "Should match 3 patterns") assert.Len(t, matches, 3, "Should match 3 patterns")
} }
// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes
// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type)
func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}}
forwarder.UpdateDomains(entries)
tests := []struct {
name string
queryType uint16
setupMocks func()
expectedCode int
expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case)
description string
}{
{
name: "domain exists but no AAAA records (NODATA)",
queryType: dns.TypeAAAA,
setupMocks: func() {
// First query for AAAA returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for A records succeeds (domain exists)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: true,
description: "Should return NOERROR when domain exists but has no records of requested type",
},
{
name: "domain exists but no A records (NODATA)",
queryType: dns.TypeA,
setupMocks: func() {
// First query for A returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for AAAA records succeeds (domain exists)
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: true,
description: "Should return NOERROR when domain exists but has no A records",
},
{
name: "domain doesn't exist (NXDOMAIN)",
queryType: dns.TypeA,
setupMocks: func() {
// First query for A returns not found
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
// Check query for AAAA also returns not found (domain doesn't exist)
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
},
expectedCode: dns.RcodeNameError,
expectNoAnswer: true,
description: "Should return NXDOMAIN when domain doesn't exist at all",
},
{
name: "domain exists with records (normal success)",
queryType: dns.TypeA,
setupMocks: func() {
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
// Expect firewall update for successful resolution
expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32)
mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once()
},
expectedCode: dns.RcodeSuccess,
expectNoAnswer: false,
description: "Should return NOERROR with answer when records exist",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset mock expectations
mockResolver.ExpectedCalls = nil
mockResolver.Calls = nil
mockFirewall.ExpectedCalls = nil
mockFirewall.Calls = nil
tt.setupMocks()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
// If a response was returned, it means it should be written (happens in wrapper functions)
if resp != nil && writtenResp == nil {
writtenResp = resp
}
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
if tt.expectNoAnswer {
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
}
mockResolver.AssertExpectations(t)
})
}
}
func TestDNSForwarder_EmptyQuery(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions // Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})

View File

@@ -55,11 +55,11 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client" mgm "github.com/netbirdio/netbird/shared/management/client"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto" sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -254,6 +254,7 @@ func NewEngine(
} }
engine.stateManager = statemanager.New(path) engine.stateManager = statemanager.New(path)
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine return engine
} }
@@ -1330,52 +1331,17 @@ func (e *Engine) receiveSignalEvents() {
} }
switch msg.GetBody().Type { switch msg.GetBody().Type {
case sProto.Body_OFFER: case sProto.Body_OFFER, sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg) offerAnswer, err := convertToOfferAnswer(msg)
if err != nil { if err != nil {
return err return err
} }
var rosenpassPubKey []byte if msg.Body.Type == sProto.Body_OFFER {
rosenpassAddr := "" conn.OnRemoteOffer(*offerAnswer)
if msg.GetBody().GetRosenpassConfig() != nil { } else {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey() conn.OnRemoteAnswer(*offerAnswer)
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
} }
conn.OnRemoteOffer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return err
}
var rosenpassPubKey []byte
rosenpassAddr := ""
if msg.GetBody().GetRosenpassConfig() != nil {
rosenpassPubKey = msg.GetBody().GetRosenpassConfig().GetRosenpassPubKey()
rosenpassAddr = msg.GetBody().GetRosenpassConfig().GetRosenpassServerAddr()
}
conn.OnRemoteAnswer(peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_CANDIDATE: case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload) candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
if err != nil { if err != nil {
@@ -2073,3 +2039,44 @@ func createFile(path string) error {
} }
return file.Close() return file.Close()
} }
func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
remoteCred, err := signal.UnMarshalCredential(msg)
if err != nil {
return nil, err
}
var (
rosenpassPubKey []byte
rosenpassAddr string
)
if cfg := msg.GetBody().GetRosenpassConfig(); cfg != nil {
rosenpassPubKey = cfg.GetRosenpassPubKey()
rosenpassAddr = cfg.GetRosenpassServerAddr()
}
// Handle optional SessionID
var sessionID *peer.ICESessionID
if sessionBytes := msg.GetBody().GetSessionId(); sessionBytes != nil {
if id, err := peer.ICESessionIDFromBytes(sessionBytes); err != nil {
log.Warnf("Invalid session ID in message: %v", err)
sessionID = nil // Set to nil if conversion fails
} else {
sessionID = &id
}
}
offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
Pwd: remoteCred.Pwd,
},
WgListenPort: int(msg.GetBody().GetWgListenPort()),
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
SessionID: sessionID,
}
return &offerAnswer, nil
}

View File

@@ -27,6 +27,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
@@ -1564,13 +1565,14 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err

View File

@@ -24,8 +24,8 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id" "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker" "github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
@@ -200,19 +200,11 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.wg.Add(1) conn.wg.Add(1)
go func() { go func() {
defer conn.wg.Done() defer conn.wg.Done()
conn.waitInitialRandomSleepTime(conn.ctx) conn.waitInitialRandomSleepTime(conn.ctx)
conn.semaphore.Done(conn.ctx) conn.semaphore.Done(conn.ctx)
conn.dumpState.SendOffer()
if err := conn.handshaker.sendOffer(); err != nil {
conn.Log.Errorf("failed to send initial offer: %v", err)
}
conn.wg.Add(1)
go func() {
conn.guard.Start(conn.ctx, conn.onGuardEvent) conn.guard.Start(conn.ctx, conn.onGuardEvent)
conn.wg.Done()
}()
}() }()
conn.opened = true conn.opened = true
return nil return nil
@@ -274,10 +266,10 @@ func (conn *Conn) Close(signalToRemote bool) {
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
conn.dumpState.RemoteAnswer() conn.dumpState.RemoteAnswer()
conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteAnswer, priority: %s, status ICE: %s, status relay: %s", conn.currentConnPriority, conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer) conn.handshaker.OnRemoteAnswer(answer)
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -296,10 +288,10 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
conn.onDisconnected = handler conn.onDisconnected = handler
} }
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { func (conn *Conn) OnRemoteOffer(offer OfferAnswer) {
conn.dumpState.RemoteOffer() conn.dumpState.RemoteOffer()
conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) conn.Log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer) conn.handshaker.OnRemoteOffer(offer)
} }
// WgConfig returns the WireGuard config // WgConfig returns the WireGuard config
@@ -548,7 +540,6 @@ func (conn *Conn) onRelayDisconnected() {
} }
func (conn *Conn) onGuardEvent() { func (conn *Conn) onGuardEvent() {
conn.Log.Debugf("send offer to peer")
conn.dumpState.SendOffer() conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil { if err := conn.handshaker.SendOffer(); err != nil {
conn.Log.Errorf("failed to send offer: %v", err) conn.Log.Errorf("failed to send offer: %v", err)
@@ -672,7 +663,7 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
} }
}() }()
if conn.statusICE.Get() == worker.StatusDisconnected { if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false return false
} }

View File

@@ -1,9 +1,9 @@
package peer package peer
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"sync"
"testing" "testing"
"time" "time"
@@ -79,16 +79,13 @@ func TestConn_OnRemoteOffer(t *testing.T) {
return return
} }
wg := sync.WaitGroup{} onNewOffeChan := make(chan struct{})
wg.Add(2)
go func() {
<-conn.handshaker.remoteOffersCh
wg.Done()
}()
go func() { conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
for { onNewOffeChan <- struct{}{}
accepted := conn.OnRemoteOffer(OfferAnswer{ })
conn.OnRemoteOffer(OfferAnswer{
IceCredentials: IceCredentials{ IceCredentials: IceCredentials{
UFrag: "test", UFrag: "test",
Pwd: "test", Pwd: "test",
@@ -96,14 +93,16 @@ func TestConn_OnRemoteOffer(t *testing.T) {
WgListenPort: 0, WgListenPort: 0,
Version: "", Version: "",
}) })
if accepted {
wg.Done()
return
}
}
}()
wg.Wait() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
select {
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
}
} }
func TestConn_OnRemoteAnswer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) {
@@ -119,16 +118,13 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
return return
} }
wg := sync.WaitGroup{} onNewOffeChan := make(chan struct{})
wg.Add(2)
go func() {
<-conn.handshaker.remoteAnswerCh
wg.Done()
}()
go func() { conn.handshaker.AddOnNewOfferListener(func(remoteOfferAnswer *OfferAnswer) {
for { onNewOffeChan <- struct{}{}
accepted := conn.OnRemoteAnswer(OfferAnswer{ })
conn.OnRemoteAnswer(OfferAnswer{
IceCredentials: IceCredentials{ IceCredentials: IceCredentials{
UFrag: "test", UFrag: "test",
Pwd: "test", Pwd: "test",
@@ -136,14 +132,15 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
WgListenPort: 0, WgListenPort: 0,
Version: "", Version: "",
}) })
if accepted { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
wg.Done() defer cancel()
return
}
}
}()
wg.Wait() select {
case <-onNewOffeChan:
// success
case <-ctx.Done():
t.Error("expected to receive a new offer notification, but timed out")
}
} }
func TestConn_presharedKey(t *testing.T) { func TestConn_presharedKey(t *testing.T) {

View File

@@ -19,7 +19,6 @@ type isConnectedFunc func() bool
// - Relayed connection disconnected // - Relayed connection disconnected
// - ICE candidate changes // - ICE candidate changes
type Guard struct { type Guard struct {
Reconnect chan struct{}
log *log.Entry log *log.Entry
isConnectedOnAllWay isConnectedFunc isConnectedOnAllWay isConnectedFunc
timeout time.Duration timeout time.Duration
@@ -30,7 +29,6 @@ type Guard struct {
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{ return &Guard{
Reconnect: make(chan struct{}, 1),
log: log, log: log,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
timeout: timeout, timeout: timeout,
@@ -41,6 +39,7 @@ func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Durati
} }
func (g *Guard) Start(ctx context.Context, eventCallback func()) { func (g *Guard) Start(ctx context.Context, eventCallback func()) {
g.log.Infof("starting guard for reconnection with MaxInterval: %s", g.timeout)
g.reconnectLoopWithRetry(ctx, eventCallback) g.reconnectLoopWithRetry(ctx, eventCallback)
} }
@@ -61,17 +60,14 @@ func (g *Guard) SetICEConnDisconnected() {
// reconnectLoopWithRetry periodically check the connection status. // reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported // Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
waitForInitialConnectionTry(ctx)
srReconnectedChan := g.srWatcher.NewListener() srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan) defer g.srWatcher.RemoveListener(srReconnectedChan)
ticker := g.prepareExponentTicker(ctx) ticker := g.initialTicker(ctx)
defer ticker.Stop() defer ticker.Stop()
tickerChannel := ticker.C tickerChannel := ticker.C
g.log.Infof("start reconnect loop...")
for { for {
select { select {
case t := <-tickerChannel: case t := <-tickerChannel:
@@ -85,7 +81,6 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
if !g.isConnectedOnAllWay() { if !g.isConnectedOnAllWay() {
callback() callback()
} }
case <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker") g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
@@ -111,6 +106,20 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
} }
} }
// initialTicker give chance to the peer to establish the initial connection.
func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 3 * time.Second,
RandomizationFactor: 0.1,
Multiplier: 2,
MaxInterval: g.timeout,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
@@ -126,13 +135,3 @@ func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
return ticker return ticker
} }
// Give chance to the peer to establish the initial connection.
// With it, we can decrease to send necessary offer
func waitForInitialConnectionTry(ctx context.Context) {
select {
case <-ctx.Done():
return
case <-time.After(3 * time.Second):
}
}

View File

@@ -39,6 +39,15 @@ type OfferAnswer struct {
// relay server address // relay server address
RelaySrvAddress string RelaySrvAddress string
// SessionID is the unique identifier of the session, used to discard old messages
SessionID *ICESessionID
}
func (oa *OfferAnswer) SessionIDString() string {
if oa.SessionID == nil {
return "unknown"
}
return oa.SessionID.String()
} }
type Handshaker struct { type Handshaker struct {
@@ -74,21 +83,25 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
func (h *Handshaker) Listen(ctx context.Context) { func (h *Handshaker) Listen(ctx context.Context) {
for { for {
h.log.Info("wait for remote offer confirmation") select {
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation(ctx) case remoteOfferAnswer := <-h.remoteOffersCh:
if err != nil { // received confirmation from the remote peer -> ready to proceed
var connectionClosedError *ConnectionClosedError if err := h.sendAnswer(); err != nil {
if errors.As(err, &connectionClosedError) { h.log.Errorf("failed to send remote offer confirmation: %s", err)
h.log.Info("exit from handshaker")
return
}
h.log.Errorf("failed to received remote offer confirmation: %s", err)
continue continue
} }
h.log.Infof("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
for _, listener := range h.onNewOfferListeners { for _, listener := range h.onNewOfferListeners {
go listener(remoteOfferAnswer) listener(&remoteOfferAnswer)
}
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
for _, listener := range h.onNewOfferListeners {
listener(&remoteOfferAnswer)
}
case <-ctx.Done():
h.log.Infof("stop listening for remote offers and answers")
return
} }
} }
} }
@@ -101,43 +114,27 @@ func (h *Handshaker) SendOffer() error {
// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool { func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) {
select { select {
case h.remoteOffersCh <- offer: case h.remoteOffersCh <- offer:
return true return
default: default:
h.log.Warnf("OnRemoteOffer skipping message because is not ready") h.log.Warnf("skipping remote offer message because receiver not ready")
// connection might not be ready yet to receive so we ignore the message // connection might not be ready yet to receive so we ignore the message
return false return
} }
} }
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool { func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) {
select { select {
case h.remoteAnswerCh <- answer: case h.remoteAnswerCh <- answer:
return true return
default: default:
// connection might not be ready yet to receive so we ignore the message // connection might not be ready yet to receive so we ignore the message
h.log.Debugf("OnRemoteAnswer skipping message because is not ready") h.log.Warnf("skipping remote answer message because receiver not ready")
return false return
}
}
func (h *Handshaker) waitForRemoteOfferConfirmation(ctx context.Context) (*OfferAnswer, error) {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed
if err := h.sendAnswer(); err != nil {
return nil, err
}
return &remoteOfferAnswer, nil
case remoteOfferAnswer := <-h.remoteAnswerCh:
return &remoteOfferAnswer, nil
case <-ctx.Done():
// closed externally
return nil, NewConnectionClosedError(h.config.Key)
} }
} }
@@ -147,43 +144,34 @@ func (h *Handshaker) sendOffer() error {
return ErrSignalIsNotReady return ErrSignalIsNotReady
} }
iceUFrag, icePwd := h.ice.GetLocalUserCredentials() offer := h.buildOfferAnswer()
offer := OfferAnswer{ h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
IceCredentials: IceCredentials{iceUFrag, icePwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
}
addr, err := h.relay.RelayInstanceAddress()
if err == nil {
offer.RelaySrvAddress = addr
}
return h.signaler.SignalOffer(offer, h.config.Key) return h.signaler.SignalOffer(offer, h.config.Key)
} }
func (h *Handshaker) sendAnswer() error { func (h *Handshaker) sendAnswer() error {
h.log.Infof("sending answer") answer := h.buildOfferAnswer()
uFrag, pwd := h.ice.GetLocalUserCredentials() h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
return h.signaler.SignalAnswer(answer, h.config.Key)
}
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{ answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd}, IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort, WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(), Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr, RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid,
} }
addr, err := h.relay.RelayInstanceAddress()
if err == nil { if addr, err := h.relay.RelayInstanceAddress(); err == nil {
answer.RelaySrvAddress = addr answer.RelaySrvAddress = addr
} }
err = h.signaler.SignalAnswer(answer, h.config.Key) return answer
if err != nil {
return err
}
return nil
} }

View File

@@ -0,0 +1,47 @@
package peer
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
)
const sessionIDSize = 5
type ICESessionID string
// NewICESessionID generates a new session ID for distinguishing sessions
func NewICESessionID() (ICESessionID, error) {
b := make([]byte, sessionIDSize)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
return "", fmt.Errorf("failed to generate session ID: %w", err)
}
return ICESessionID(hex.EncodeToString(b)), nil
}
func ICESessionIDFromBytes(b []byte) (ICESessionID, error) {
if len(b) != sessionIDSize {
return "", fmt.Errorf("invalid session ID length: %d", len(b))
}
return ICESessionID(hex.EncodeToString(b)), nil
}
// Bytes returns the raw bytes of the session ID for protobuf serialization
func (id ICESessionID) Bytes() ([]byte, error) {
if len(id) == 0 {
return nil, fmt.Errorf("ICE session ID is empty")
}
b, err := hex.DecodeString(string(id))
if err != nil {
return nil, fmt.Errorf("invalid ICE session ID encoding: %w", err)
}
if len(b) != sessionIDSize {
return nil, fmt.Errorf("invalid ICE session ID length: expected %d bytes, got %d", sessionIDSize, len(b))
}
return b, nil
}
func (id ICESessionID) String() string {
return string(id)
}

View File

@@ -2,6 +2,7 @@ package peer
import ( import (
"github.com/pion/ice/v3" "github.com/pion/ice/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -45,6 +46,10 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer // SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
msg, err := signal.MarshalCredential( msg, err := signal.MarshalCredential(
s.wgPrivateKey, s.wgPrivateKey,
offerAnswer.WgListenPort, offerAnswer.WgListenPort,
@@ -56,13 +61,13 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
bodyType, bodyType,
offerAnswer.RosenpassPubKey, offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassAddr, offerAnswer.RosenpassAddr,
offerAnswer.RelaySrvAddress) offerAnswer.RelaySrvAddress,
sessionIDBytes)
if err != nil { if err != nil {
return err return err
} }
err = s.signal.Send(msg) if err = s.signal.Send(msg); err != nil {
if err != nil {
return err return err
} }

View File

@@ -43,6 +43,16 @@ type WorkerICE struct {
hasRelayOnLocally bool hasRelayOnLocally bool
agent *ice.Agent agent *ice.Agent
agentDialerCancel context.CancelFunc
agentConnecting bool // while it is true, drop all incoming offers
lastSuccess time.Time // with this avoid the too frequent ICE agent recreation
// remoteSessionID represents the peer's session identifier from the latest remote offer.
remoteSessionID ICESessionID
// sessionID is used to track the current session ID of the ICE agent
// increase by one when disconnecting the agent
// with it the remote peer can discard the already deprecated offer/answer
// Without it the remote peer may recreate a workable ICE connection
sessionID ICESessionID
muxAgent sync.Mutex muxAgent sync.Mutex
StunTurn []*stun.URI StunTurn []*stun.URI
@@ -57,6 +67,11 @@ type WorkerICE struct {
} }
func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) { func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *Conn, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool) (*WorkerICE, error) {
sessionID, err := NewICESessionID()
if err != nil {
return nil, err
}
w := &WorkerICE{ w := &WorkerICE{
ctx: ctx, ctx: ctx,
log: log, log: log,
@@ -67,6 +82,7 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
hasRelayOnLocally: hasRelayOnLocally, hasRelayOnLocally: hasRelayOnLocally,
lastKnownState: ice.ConnectionStateDisconnected, lastKnownState: ice.ConnectionStateDisconnected,
sessionID: sessionID,
} }
localUfrag, localPwd, err := icemaker.GenerateICECredentials() localUfrag, localPwd, err := icemaker.GenerateICECredentials()
@@ -79,14 +95,34 @@ func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, conn *
} }
func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("OnNewOffer for ICE") w.log.Debugf("OnNewOffer for ICE, serial: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Lock() w.muxAgent.Lock()
if w.agentConnecting {
w.log.Debugf("agent connection is in progress, skipping the offer")
w.muxAgent.Unlock()
return
}
if w.agent != nil { if w.agent != nil {
// backward compatibility with old clients that do not send session ID
if remoteOfferAnswer.SessionID == nil {
w.log.Debugf("agent already exists, skipping the offer") w.log.Debugf("agent already exists, skipping the offer")
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
if w.remoteSessionID == *remoteOfferAnswer.SessionID {
w.log.Debugf("agent already exists and session ID matches, skipping the offer: %s", remoteOfferAnswer.SessionIDString())
w.muxAgent.Unlock()
return
}
w.log.Debugf("agent already exists, recreate the connection")
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
// todo consider to switch to Relay connection while establishing a new ICE connection
}
var preferredCandidateTypes []ice.CandidateType var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
@@ -96,36 +132,124 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
} }
w.log.Debugf("recreate ICE agent") w.log.Debugf("recreate ICE agent")
agentCtx, agentCancel := context.WithCancel(w.ctx) dialerCtx, dialerCancel := context.WithCancel(w.ctx)
agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes) agent, err := w.reCreateAgent(dialerCancel, preferredCandidateTypes)
if err != nil { if err != nil {
w.log.Errorf("failed to recreate ICE Agent: %s", err) w.log.Errorf("failed to recreate ICE Agent: %s", err)
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
w.sentExtraSrflx = false
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
w.muxAgent.Unlock() w.muxAgent.Unlock()
w.log.Debugf("gather candidates") go w.connect(dialerCtx, agent, remoteOfferAnswer)
err = w.agent.GatherCandidates() }
if err != nil {
w.log.Debugf("failed to gather candidates: %s", err) // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
if w.agent == nil {
w.log.Warnf("ICE Agent is not initialized yet")
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) InProgress() bool {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.agentConnecting
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
w.agentDialerCancel()
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
w.agent = nil
}
func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
if err := agent.OnCandidate(w.onICECandidate); err != nil {
return nil, err
}
if err := agent.OnConnectionStateChange(w.onConnectionStateChange(agent, dialerCancel)); err != nil {
return nil, err
}
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
return nil, err
}
if err := agent.OnSuccessfulSelectedPairBindingResponse(w.onSuccessfulSelectedPairBindingResponse); err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) SessionID() ICESessionID {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.sessionID
}
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
func (w *WorkerICE) connect(ctx context.Context, agent *ice.Agent, remoteOfferAnswer *OfferAnswer) {
w.log.Debugf("gather candidates")
if err := agent.GatherCandidates(); err != nil {
w.log.Warnf("failed to gather candidates: %s", err)
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
// will block until connection succeeded
// but it won't release if ICE Agent went into Disconnected or Failed state,
// so we have to cancel it with the provided context once agent detected a broken connection
w.log.Debugf("turn agent dial") w.log.Debugf("turn agent dial")
remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer) remoteConn, err := w.turnAgentDial(ctx, remoteOfferAnswer)
if err != nil { if err != nil {
w.log.Debugf("failed to dial the remote peer: %s", err) w.log.Debugf("failed to dial the remote peer: %s", err)
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
w.log.Debugf("agent dial succeeded") w.log.Debugf("agent dial succeeded")
pair, err := w.agent.GetSelectedCandidatePair() pair, err := agent.GetSelectedCandidatePair()
if err != nil { if err != nil {
w.closeAgent(agent, w.agentDialerCancel)
return return
} }
@@ -152,114 +276,38 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RelayedOnLocal: isRelayCandidate(pair.Local), RelayedOnLocal: isRelayCandidate(pair.Local),
} }
w.log.Debugf("on ICE conn is ready to use") w.log.Debugf("on ICE conn is ready to use")
go w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. w.log.Infof("connection succeeded with offer session: %s", remoteOfferAnswer.SessionIDString())
func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
w.muxAgent.Lock() w.muxAgent.Lock()
defer w.muxAgent.Unlock() w.agentConnecting = false
w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String()) w.lastSuccess = time.Now()
if w.agent == nil { if remoteOfferAnswer.SessionID != nil {
w.log.Warnf("ICE Agent is not initialized yet") w.remoteSessionID = *remoteOfferAnswer.SessionID
return
} }
w.muxAgent.Unlock()
if candidateViaRoutes(candidate, haRoutes) { // todo: the potential problem is a race between the onConnectionStateChange
return w.conn.onICEConnectionIsReady(selectedPriority(pair), ci)
}
err := w.agent.AddRemoteCandidate(candidate)
if err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
} }
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { func (w *WorkerICE) closeAgent(agent *ice.Agent, cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
return w.localUfrag, w.localPwd
}
func (w *WorkerICE) Close() {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
if w.agent == nil {
return
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}
}
func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, candidates []ice.CandidateType) (*ice.Agent, error) {
w.sentExtraSrflx = false
agent, err := icemaker.NewAgent(w.iFaceDiscover, w.config.ICEConfig, candidates, w.localUfrag, w.localPwd)
if err != nil {
return nil, fmt.Errorf("create agent: %w", err)
}
err = agent.OnCandidate(w.onICECandidate)
if err != nil {
return nil, err
}
err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agentCancel)
default:
return
}
})
if err != nil {
return nil, err
}
err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
if err != nil {
return nil, err
}
err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
if err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
})
if err != nil {
return nil, fmt.Errorf("failed setting binding response callback: %w", err)
}
return agent, nil
}
func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
w.muxAgent.Lock()
defer w.muxAgent.Unlock()
cancel() cancel()
if w.agent == nil { if err := agent.Close(); err != nil {
return
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err) w.log.Warnf("failed to close ICE agent: %s", err)
} }
w.muxAgent.Lock()
sessionID, err := NewICESessionID()
if err != nil {
w.log.Errorf("failed to create new session ID: %s", err)
}
w.sessionID = sessionID
if w.agent == agent {
w.agent = nil w.agent = nil
w.agentConnecting = false
}
w.muxAgent.Unlock()
} }
func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
@@ -331,6 +379,32 @@ func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidat
w.config.Key) w.config.Key)
} }
func (w *WorkerICE) onConnectionStateChange(agent *ice.Agent, dialerCancel context.CancelFunc) func(ice.ConnectionState) {
return func(state ice.ConnectionState) {
w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
switch state {
case ice.ConnectionStateConnected:
w.lastKnownState = ice.ConnectionStateConnected
return
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
if w.lastKnownState == ice.ConnectionStateConnected {
w.lastKnownState = ice.ConnectionStateDisconnected
w.conn.onICEStateDisconnected()
}
w.closeAgent(agent, dialerCancel)
default:
return
}
}
}
func (w *WorkerICE) onSuccessfulSelectedPairBindingResponse(pair *ice.CandidatePair) {
if err := w.statusRecorder.UpdateLatency(w.config.Key, pair.Latency()); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true return true

View File

@@ -4430,6 +4430,94 @@ func (*LogoutResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{66} return file_daemon_proto_rawDescGZIP(), []int{66}
} }
type GetFeaturesRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GetFeaturesRequest) Reset() {
*x = GetFeaturesRequest{}
mi := &file_daemon_proto_msgTypes[67]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GetFeaturesRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetFeaturesRequest) ProtoMessage() {}
func (x *GetFeaturesRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[67]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetFeaturesRequest.ProtoReflect.Descriptor instead.
func (*GetFeaturesRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{67}
}
type GetFeaturesResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *GetFeaturesResponse) Reset() {
*x = GetFeaturesResponse{}
mi := &file_daemon_proto_msgTypes[68]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *GetFeaturesResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*GetFeaturesResponse) ProtoMessage() {}
func (x *GetFeaturesResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[68]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use GetFeaturesResponse.ProtoReflect.Descriptor instead.
func (*GetFeaturesResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{68}
}
func (x *GetFeaturesResponse) GetDisableProfiles() bool {
if x != nil {
return x.DisableProfiles
}
return false
}
func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool {
if x != nil {
return x.DisableUpdateSettings
}
return false
}
type PortInfo_Range struct { type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -4440,7 +4528,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() { func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{} *x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[68] mi := &file_daemon_proto_msgTypes[70]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@@ -4452,7 +4540,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {} func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[68] mi := &file_daemon_proto_msgTypes[70]
if x != nil { if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@@ -4872,7 +4960,11 @@ const file_daemon_proto_rawDesc = "" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" + "\f_profileNameB\v\n" +
"\t_username\"\x10\n" + "\t_username\"\x10\n" +
"\x0eLogoutResponse*b\n" + "\x0eLogoutResponse\"\x14\n" +
"\x12GetFeaturesRequest\"x\n" +
"\x13GetFeaturesResponse\x12)\n" +
"\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" +
"\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings*b\n" +
"\bLogLevel\x12\v\n" + "\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" + "\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" +
@@ -4881,7 +4973,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" + "\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" + "\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xc5\x0f\n" + "\x05TRACE\x10\a2\x8f\x10\n" +
"\rDaemonService\x126\n" + "\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -4912,7 +5004,8 @@ const file_daemon_proto_rawDesc = "" +
"\rRemoveProfile\x12\x1c.daemon.RemoveProfileRequest\x1a\x1d.daemon.RemoveProfileResponse\"\x00\x12K\n" + "\rRemoveProfile\x12\x1c.daemon.RemoveProfileRequest\x1a\x1d.daemon.RemoveProfileResponse\"\x00\x12K\n" +
"\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" + "\fListProfiles\x12\x1b.daemon.ListProfilesRequest\x1a\x1c.daemon.ListProfilesResponse\"\x00\x12W\n" +
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00B\bZ\x06/protob\x06proto3" "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" +
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00B\bZ\x06/protob\x06proto3"
var ( var (
file_daemon_proto_rawDescOnce sync.Once file_daemon_proto_rawDescOnce sync.Once
@@ -4927,7 +5020,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
} }
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 70) var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 72)
var file_daemon_proto_goTypes = []any{ var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel (LogLevel)(0), // 0: daemon.LogLevel
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
@@ -4999,18 +5092,20 @@ var file_daemon_proto_goTypes = []any{
(*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse (*GetActiveProfileResponse)(nil), // 67: daemon.GetActiveProfileResponse
(*LogoutRequest)(nil), // 68: daemon.LogoutRequest (*LogoutRequest)(nil), // 68: daemon.LogoutRequest
(*LogoutResponse)(nil), // 69: daemon.LogoutResponse (*LogoutResponse)(nil), // 69: daemon.LogoutResponse
nil, // 70: daemon.Network.ResolvedIPsEntry (*GetFeaturesRequest)(nil), // 70: daemon.GetFeaturesRequest
(*PortInfo_Range)(nil), // 71: daemon.PortInfo.Range (*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse
nil, // 72: daemon.SystemEvent.MetadataEntry nil, // 72: daemon.Network.ResolvedIPsEntry
(*durationpb.Duration)(nil), // 73: google.protobuf.Duration (*PortInfo_Range)(nil), // 73: daemon.PortInfo.Range
(*timestamppb.Timestamp)(nil), // 74: google.protobuf.Timestamp nil, // 74: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 75: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 76: google.protobuf.Timestamp
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
73, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 75, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
74, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 76, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
74, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 76, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
73, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 75, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -5019,8 +5114,8 @@ var file_daemon_proto_depIdxs = []int32{
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent 52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network 28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
70, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 72, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
71, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 73, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule 30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -5031,10 +5126,10 @@ var file_daemon_proto_depIdxs = []int32{
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category 2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
74, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp 76, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
72, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 74, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent 52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
73, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 75, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -5064,35 +5159,37 @@ var file_daemon_proto_depIdxs = []int32{
63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest 63, // 55: daemon.DaemonService.ListProfiles:input_type -> daemon.ListProfilesRequest
66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest 66, // 56: daemon.DaemonService.GetActiveProfile:input_type -> daemon.GetActiveProfileRequest
68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest 68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
5, // 58: daemon.DaemonService.Login:output_type -> daemon.LoginResponse 70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
7, // 59: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse 5, // 59: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
9, // 60: daemon.DaemonService.Up:output_type -> daemon.UpResponse 7, // 60: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
11, // 61: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 9, // 61: daemon.DaemonService.Up:output_type -> daemon.UpResponse
13, // 62: daemon.DaemonService.Down:output_type -> daemon.DownResponse 11, // 62: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
15, // 63: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse 13, // 63: daemon.DaemonService.Down:output_type -> daemon.DownResponse
24, // 64: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse 15, // 64: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
26, // 65: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse 24, // 65: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 66: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse 26, // 66: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 67: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse 26, // 67: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
33, // 68: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 31, // 68: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
35, // 69: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 33, // 69: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
37, // 70: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse 35, // 70: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
40, // 71: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse 37, // 71: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
42, // 72: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse 40, // 72: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
44, // 73: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse 42, // 73: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
46, // 74: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse 44, // 74: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
50, // 75: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse 46, // 75: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
52, // 76: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent 50, // 76: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
54, // 77: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse 52, // 77: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
56, // 78: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse 54, // 78: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
58, // 79: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse 56, // 79: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
60, // 80: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse 58, // 80: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
62, // 81: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse 60, // 81: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
64, // 82: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse 62, // 82: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
67, // 83: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse 64, // 83: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
69, // 84: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse 67, // 84: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
58, // [58:85] is the sub-list for method output_type 69, // 85: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
31, // [31:58] is the sub-list for method input_type 71, // 86: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
59, // [59:87] is the sub-list for method output_type
31, // [31:59] is the sub-list for method input_type
31, // [31:31] is the sub-list for extension type_name 31, // [31:31] is the sub-list for extension type_name
31, // [31:31] is the sub-list for extension extendee 31, // [31:31] is the sub-list for extension extendee
0, // [0:31] is the sub-list for field type_name 0, // [0:31] is the sub-list for field type_name
@@ -5120,7 +5217,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 3, NumEnums: 3,
NumMessages: 70, NumMessages: 72,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -82,6 +82,8 @@ service DaemonService {
// Logout disconnects from the network and deletes the peer from the management server // Logout disconnects from the network and deletes the peer from the management server
rpc Logout(LogoutRequest) returns (LogoutResponse) {} rpc Logout(LogoutRequest) returns (LogoutResponse) {}
rpc GetFeatures(GetFeaturesRequest) returns (GetFeaturesResponse) {}
} }
@@ -625,3 +627,10 @@ message LogoutRequest {
} }
message LogoutResponse {} message LogoutResponse {}
message GetFeaturesRequest{}
message GetFeaturesResponse{
bool disable_profiles = 1;
bool disable_update_settings = 2;
}

View File

@@ -63,6 +63,7 @@ type DaemonServiceClient interface {
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error) GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
// Logout disconnects from the network and deletes the peer from the management server // Logout disconnects from the network and deletes the peer from the management server
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
} }
type daemonServiceClient struct { type daemonServiceClient struct {
@@ -339,6 +340,15 @@ func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opt
return out, nil return out, nil
} }
func (c *daemonServiceClient) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) {
out := new(GetFeaturesResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetFeatures", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility // for forward compatibility
@@ -388,6 +398,7 @@ type DaemonServiceServer interface {
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
// Logout disconnects from the network and deletes the peer from the management server // Logout disconnects from the network and deletes the peer from the management server
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
@@ -476,6 +487,9 @@ func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetAc
func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) { func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented") return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
} }
func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetFeatures not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -978,6 +992,24 @@ func _DaemonService_Logout_Handler(srv interface{}, ctx context.Context, dec fun
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _DaemonService_GetFeatures_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetFeaturesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).GetFeatures(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/GetFeatures",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetFeatures(ctx, req.(*GetFeaturesRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -1089,6 +1121,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "Logout", MethodName: "Logout",
Handler: _DaemonService_Logout_Handler, Handler: _DaemonService_Logout_Handler,
}, },
{
MethodName: "GetFeatures",
Handler: _DaemonService_GetFeatures_Handler,
},
}, },
Streams: []grpc.StreamDesc{ Streams: []grpc.StreamDesc{
{ {

View File

@@ -48,6 +48,7 @@ const (
errRestoreResidualState = "failed to restore residual state: %v" errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
) )
var ErrServiceNotUp = errors.New("service is not up") var ErrServiceNotUp = errors.New("service is not up")
@@ -76,6 +77,7 @@ type Server struct {
profileManager *profilemanager.ServiceManager profileManager *profilemanager.ServiceManager
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool
} }
type oauthAuthFlow struct { type oauthAuthFlow struct {
@@ -86,7 +88,7 @@ type oauthAuthFlow struct {
} }
// New server instance constructor. // New server instance constructor.
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool) *Server { func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool) *Server {
return &Server{ return &Server{
rootCtx: ctx, rootCtx: ctx,
logFile: logFile, logFile: logFile,
@@ -94,6 +96,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
statusRecorder: peer.NewRecorder(""), statusRecorder: peer.NewRecorder(""),
profileManager: profilemanager.NewServiceManager(configFile), profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled, profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
} }
} }
@@ -322,8 +325,8 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.checkProfilesDisabled() { if s.checkUpdateSettingsDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled) return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
} }
profState := profilemanager.ActiveProfileState{ profState := profilemanager.ActiveProfileState{
@@ -1330,10 +1333,31 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
}, nil }, nil
} }
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) checkProfilesDisabled() bool { func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles // Check if the environment variable is set to disable profiles
if s.profilesDisabled { if s.profilesDisabled {
log.Warn("Profiles are disabled via NB_DISABLE_PROFILES environment variable") return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true return true
} }

View File

@@ -14,6 +14,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -24,7 +25,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto" daemonProto "github.com/netbirdio/netbird/client/proto"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
@@ -33,6 +33,7 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server" signalServer "github.com/netbirdio/netbird/signal/server"
) )
@@ -94,7 +95,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err) t.Fatalf("failed to set active profile state: %v", err)
} }
s := New(ctx, "debug", "", false) s := New(ctx, "debug", "", false, false)
s.config = config s.config = config
@@ -151,7 +152,7 @@ func TestServer_Up(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err) t.Fatalf("failed to set active profile state: %v", err)
} }
s := New(ctx, "console", "", false) s := New(ctx, "console", "", false, false)
err = s.Start() err = s.Start()
require.NoError(t, err) require.NoError(t, err)
@@ -227,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
t.Fatalf("failed to set active profile state: %v", err) t.Fatalf("failed to set active profile state: %v", err)
} }
s := New(ctx, "console", "", false) s := New(ctx, "console", "", false, false)
err = s.Start() err = s.Start()
require.NoError(t, err) require.NoError(t, err)
@@ -302,13 +303,14 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err

View File

@@ -392,6 +392,16 @@ func (s *serviceClient) updateIcon() {
} }
func (s *serviceClient) showSettingsUI() { func (s *serviceClient) showSettingsUI() {
// Check if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Update settings are disabled by daemon")
return
}
// add settings window UI elements. // add settings window UI elements.
s.wSettings = s.app.NewWindow("NetBird Settings") s.wSettings = s.app.NewWindow("NetBird Settings")
s.wSettings.SetOnClosed(s.cancel) s.wSettings.SetOnClosed(s.cancel)
@@ -447,6 +457,17 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
}, },
SubmitText: "Save", SubmitText: "Save",
OnSubmit: func() { OnSubmit: func() {
// Check if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else if features != nil && features.DisableUpdateSettings {
log.Warn("Configuration updates are disabled by daemon")
dialog.ShowError(fmt.Errorf("Configuration updates are disabled by daemon"), s.wSettings)
return
}
if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey { if s.iPreSharedKey.Text != "" && s.iPreSharedKey.Text != censoredPreSharedKey {
// validate preSharedKey if it added // validate preSharedKey if it added
if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil { if _, err := wgtypes.ParseKey(s.iPreSharedKey.Text); err != nil {
@@ -836,6 +857,20 @@ func (s *serviceClient) onTrayReady() {
s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr) s.mCreateDebugBundle = s.mSettings.AddSubMenuItem("Create Debug Bundle", debugBundleMenuDescr)
s.loadSettings() s.loadSettings()
// Disable settings menu if update settings are disabled by daemon
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
// Continue with default behavior if features can't be retrieved
} else {
if features != nil && features.DisableUpdateSettings {
s.setSettingsEnabled(false)
}
if features != nil && features.DisableProfiles {
s.mProfile.setEnabled(false)
}
}
s.exitNodeMu.Lock() s.exitNodeMu.Lock()
s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr) s.mExitNode = systray.AddMenuItem("Exit Node", exitNodeMenuDescr)
s.mExitNode.Disable() s.mExitNode.Disable()
@@ -876,6 +911,10 @@ func (s *serviceClient) onTrayReady() {
if err != nil { if err != nil {
log.Errorf("error while updating status: %v", err) log.Errorf("error while updating status: %v", err)
} }
// Check features periodically to handle daemon restarts
s.checkAndUpdateFeatures()
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
} }
}() }()
@@ -948,6 +987,59 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
return s.conn, nil return s.conn, nil
} }
// setSettingsEnabled enables or disables the settings menu based on the provided state
func (s *serviceClient) setSettingsEnabled(enabled bool) {
if s.mSettings != nil {
if enabled {
s.mSettings.Enable()
s.mSettings.SetTooltip(settingsMenuDescr)
} else {
s.mSettings.Hide()
s.mSettings.SetTooltip("Settings are disabled by daemon")
}
}
}
// checkAndUpdateFeatures checks the current features and updates the UI accordingly
func (s *serviceClient) checkAndUpdateFeatures() {
features, err := s.getFeatures()
if err != nil {
log.Errorf("failed to get features from daemon: %v", err)
return
}
// Update settings menu based on current features
if features != nil && features.DisableUpdateSettings {
s.setSettingsEnabled(false)
} else {
s.setSettingsEnabled(true)
}
// Update profile menu based on current features
if s.mProfile != nil {
if features != nil && features.DisableProfiles {
s.mProfile.setEnabled(false)
} else {
s.mProfile.setEnabled(true)
}
}
}
// getFeatures from the daemon to determine which features are enabled/disabled.
func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
conn, err := s.getSrvClient(failFastTimeout)
if err != nil {
return nil, fmt.Errorf("get client for features: %w", err)
}
features, err := conn.GetFeatures(s.ctx, &proto.GetFeaturesRequest{})
if err != nil {
return nil, fmt.Errorf("get features from daemon: %w", err)
}
return features, nil
}
// getSrvConfig from the service to show it in the settings window. // getSrvConfig from the service to show it in the settings window.
func (s *serviceClient) getSrvConfig() { func (s *serviceClient) getSrvConfig() {
s.managementURL = profilemanager.DefaultManagementURL s.managementURL = profilemanager.DefaultManagementURL

View File

@@ -654,6 +654,19 @@ func (p *profileMenu) clear(profiles []Profile) {
} }
} }
// setEnabled enables or disables the profile menu based on the provided state
func (p *profileMenu) setEnabled(enabled bool) {
if p.profileMenuItem != nil {
if enabled {
p.profileMenuItem.Enable()
p.profileMenuItem.SetTooltip("")
} else {
p.profileMenuItem.Hide()
p.profileMenuItem.SetTooltip("Profiles are disabled by daemon")
}
}
}
func (p *profileMenu) updateMenu() { func (p *profileMenu) updateMenu() {
// check every second // check every second
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
@@ -662,7 +675,6 @@ func (p *profileMenu) updateMenu() {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
// get profilesList // get profilesList
profiles, err := p.getProfiles() profiles, err := p.getProfiles()
if err != nil { if err != nil {

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f h1:YmqNWdRbeVn1lSpkLzIiFHX2cndRuaVYyynx2ibrOtg= github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e h1:S85laGfx1UP+nmRF9smP6/TY965kLWz41PbBK1TX8g0=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250805121557-5f225a973d1f/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250812185008-dfc66fa49a2e/go.mod h1:Jjve0+eUjOLKL3PJtAhjfM2iJ0SxWio5elHqlV1ymP8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -0,0 +1,2 @@
GRANT SYSTEM_VARIABLES_ADMIN ON *.* TO 'netbird'@'%';
FLUSH PRIVILEGES;

View File

@@ -34,6 +34,7 @@ import (
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/peers" "github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
@@ -45,7 +46,6 @@ import (
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http" nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/metrics" "github.com/netbirdio/netbird/management/server/metrics"
@@ -220,7 +220,8 @@ var (
return fmt.Errorf("build default manager: %v", err) return fmt.Errorf("build default manager: %v", err)
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager) groupsManager := groups.NewManager(store, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsManager, groupsManager)
trustedPeers := config.ReverseProxy.TrustedPeers trustedPeers := config.ReverseProxy.TrustedPeers
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")} defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
@@ -277,7 +278,6 @@ var (
config.GetAuthAudiences(), config.GetAuthAudiences(),
config.HttpConfig.IdpSignKeyRefreshEnabled) config.HttpConfig.IdpSignKeyRefreshEnabled)
groupsManager := groups.NewManager(store, permissionsManager, accountManager)
resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager) resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager)
routersManager := routers.NewManager(store, permissionsManager, accountManager) routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)

View File

@@ -21,6 +21,7 @@ type Manager interface {
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error)
RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error)
GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error)
} }
type managerImpl struct { type managerImpl struct {
@@ -142,6 +143,10 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
} }
func (m *managerImpl) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
return m.store.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID)
}
func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum { func ToGroupsInfoMap(groups []*types.Group, idCount int) map[string][]api.GroupMinimum {
groupsInfoMap := make(map[string][]api.GroupMinimum, idCount) groupsInfoMap := make(map[string][]api.GroupMinimum, idCount)
groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation) groupsChecked := make(map[string]struct{}, len(groups)) // not sure why this is needed (left over from old implementation)
@@ -202,6 +207,10 @@ func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context,
}, nil }, nil
} }
func (m *mockManager) GetPeerGroupIDs(ctx context.Context, accountID, peerID string) ([]string, error) {
return []string{}, nil
}
func NewManagerMock() Manager { func NewManagerMock() Manager {
return &mockManager{} return &mockManager{}
} }

View File

@@ -662,7 +662,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
} }
} }
func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings) *proto.SyncResponse { func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string) *proto.SyncResponse {
response := &proto.SyncResponse{ response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
@@ -674,7 +674,7 @@ func toSyncResponse(ctx context.Context, config *types.Config, peer *nbpeer.Peer
} }
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings) nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, nbConfig, extraSettings) extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig response.NetbirdConfig = extendedConfig
response.NetworkMap.PeerConfig = response.PeerConfig response.NetworkMap.PeerConfig = response.PeerConfig
@@ -750,7 +750,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request") return status.Errorf(codes.Internal, "error handling request")
} }
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra) peerGroups, err := getPeerGroupIDs(ctx, s.accountManager.GetStore(), peer.AccountID, peer.ID)
if err != nil {
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
}
plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil { if err != nil {

View File

@@ -199,6 +199,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
settings.Extra = &types.ExtraSettings{ settings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled, PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled, FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled, FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
} }
} }
@@ -327,6 +328,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
apiSettings.Extra = &api.AccountExtraSettings{ apiSettings.Extra = &api.AccountExtraSettings{
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled, NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled, NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
} }
} }

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -446,6 +447,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
Return(&types.ExtraSettings{}, nil). Return(&types.ExtraSettings{}, nil).
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
@@ -455,7 +457,7 @@ func startManagementForTest(t *testing.T, testFile string, config *types.Config)
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})

View File

@@ -23,6 +23,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/shared/management/proto" mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -216,7 +217,8 @@ func startServer(
t.Fatalf("failed creating an account manager: %v", err) t.Fatalf("failed creating an account manager: %v", err)
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) groupsManager := groups.NewManager(str, permissionsManager, accountManager)
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer( mgmtServer, err := server.NewServer(
context.Background(), context.Background(),
config, config,

View File

@@ -1275,8 +1275,9 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
} }
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now() start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups))
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
@@ -1386,7 +1387,8 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return return
} }
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings) peerGroups := account.GetPeerGroups(peerId)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups))
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
} }

View File

@@ -1164,7 +1164,7 @@ func TestToSyncResponse(t *testing.T) {
} }
dnsCache := &DNSConfigCache{} dnsCache := &DNSConfigCache{}
accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true}
response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil) response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{})
assert.NotNil(t, response) assert.NotNil(t, response)
// assert peer config // assert peer config

View File

@@ -68,6 +68,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
// Once we migrate the peer approval to settings manager this merging is obsolete // Once we migrate the peer approval to settings manager this merging is obsolete
if settings.Extra != nil { if settings.Extra != nil {
settings.Extra.FlowEnabled = extraSettings.FlowEnabled settings.Extra.FlowEnabled = extraSettings.FlowEnabled
settings.Extra.FlowGroups = extraSettings.FlowGroups
settings.Extra.FlowPacketCounterEnabled = extraSettings.FlowPacketCounterEnabled settings.Extra.FlowPacketCounterEnabled = extraSettings.FlowPacketCounterEnabled
settings.Extra.FlowENCollectionEnabled = extraSettings.FlowENCollectionEnabled settings.Extra.FlowENCollectionEnabled = extraSettings.FlowENCollectionEnabled
settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled
@@ -93,6 +94,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
} }
settings.Extra.FlowEnabled = extraSettings.FlowEnabled settings.Extra.FlowEnabled = extraSettings.FlowEnabled
settings.Extra.FlowGroups = extraSettings.FlowGroups
return settings.Extra, nil return settings.Extra, nil
} }

View File

@@ -11,13 +11,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/proto" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac" auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2" authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
) )
const defaultDuration = 12 * time.Hour const defaultDuration = 12 * time.Hour
@@ -39,13 +39,14 @@ type TimeBasedAuthSecretsManager struct {
relayHmacToken *authv2.Generator relayHmacToken *authv2.Generator
updateManager *PeersUpdateManager updateManager *PeersUpdateManager
settingsManager settings.Manager settingsManager settings.Manager
groupsManager groups.Manager
turnCancelMap map[string]chan struct{} turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{} relayCancelMap map[string]chan struct{}
} }
type Token auth.Token type Token auth.Token
func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager) *TimeBasedAuthSecretsManager { func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *types.TURNConfig, relayCfg *types.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
mgr := &TimeBasedAuthSecretsManager{ mgr := &TimeBasedAuthSecretsManager{
updateManager: updateManager, updateManager: updateManager,
turnCfg: turnCfg, turnCfg: turnCfg,
@@ -53,6 +54,7 @@ func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *
turnCancelMap: make(map[string]chan struct{}), turnCancelMap: make(map[string]chan struct{}),
relayCancelMap: make(map[string]chan struct{}), relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager, settingsManager: settingsManager,
groupsManager: groupsManager,
} }
if turnCfg != nil { if turnCfg != nil {
@@ -258,6 +260,11 @@ func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, p
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err) log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
} }
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, update.NetbirdConfig, extraSettings) peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
}
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
update.NetbirdConfig = extendedConfig update.NetbirdConfig = extendedConfig
} }

View File

@@ -13,9 +13,10 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -40,13 +41,14 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager, groupsManager)
turnCredentials, err := tested.GenerateTurnToken() turnCredentials, err := tested.GenerateTurnToken()
require.NoError(t, err) require.NoError(t, err)
@@ -91,13 +93,14 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager, groupsManager)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -193,13 +196,14 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{ tested := NewTimeBasedAuthSecretsManager(peersManager, &types.TURNConfig{
CredentialsTTL: ttl, CredentialsTTL: ttl,
Secret: secret, Secret: secret,
Turns: []*types.Host{TurnTestHost}, Turns: []*types.Host{TurnTestHost},
TimeBasedCredentials: true, TimeBasedCredentials: true,
}, rc, settingsMockManager) }, rc, settingsMockManager, groupsManager)
tested.SetupRefresh(context.Background(), "someAccountID", peer) tested.SetupRefresh(context.Background(), "someAccountID", peer)
if _, ok := tested.turnCancelMap[peer]; !ok { if _, ok := tested.turnCancelMap[peer]; !ok {

View File

@@ -2,6 +2,7 @@ package types
import ( import (
"net/netip" "net/netip"
"slices"
"time" "time"
) )
@@ -88,6 +89,7 @@ type ExtraSettings struct {
IntegratedValidatorGroups []string `gorm:"serializer:json"` IntegratedValidatorGroups []string `gorm:"serializer:json"`
FlowEnabled bool `gorm:"-"` FlowEnabled bool `gorm:"-"`
FlowGroups []string `gorm:"-"`
FlowPacketCounterEnabled bool `gorm:"-"` FlowPacketCounterEnabled bool `gorm:"-"`
FlowENCollectionEnabled bool `gorm:"-"` FlowENCollectionEnabled bool `gorm:"-"`
FlowDnsCollectionEnabled bool `gorm:"-"` FlowDnsCollectionEnabled bool `gorm:"-"`
@@ -95,13 +97,12 @@ type ExtraSettings struct {
// Copy copies the ExtraSettings struct // Copy copies the ExtraSettings struct
func (e *ExtraSettings) Copy() *ExtraSettings { func (e *ExtraSettings) Copy() *ExtraSettings {
var cpGroup []string
return &ExtraSettings{ return &ExtraSettings{
PeerApprovalEnabled: e.PeerApprovalEnabled, PeerApprovalEnabled: e.PeerApprovalEnabled,
IntegratedValidatorGroups: append(cpGroup, e.IntegratedValidatorGroups...), IntegratedValidatorGroups: slices.Clone(e.IntegratedValidatorGroups),
IntegratedValidator: e.IntegratedValidator, IntegratedValidator: e.IntegratedValidator,
FlowEnabled: e.FlowEnabled, FlowEnabled: e.FlowEnabled,
FlowGroups: slices.Clone(e.FlowGroups),
FlowPacketCounterEnabled: e.FlowPacketCounterEnabled, FlowPacketCounterEnabled: e.FlowPacketCounterEnabled,
FlowENCollectionEnabled: e.FlowENCollectionEnabled, FlowENCollectionEnabled: e.FlowENCollectionEnabled,
FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled, FlowDnsCollectionEnabled: e.FlowDnsCollectionEnabled,

33
relay/cmd/pprof.go Normal file
View File

@@ -0,0 +1,33 @@
//go:build pprof
// +build pprof
package cmd
import (
"net/http"
_ "net/http/pprof"
"os"
log "github.com/sirupsen/logrus"
)
func init() {
addr := pprofAddr()
go pprof(addr)
}
func pprofAddr() string {
listenAddr := os.Getenv("NB_PPROF_ADDR")
if listenAddr == "" {
return "localhost:6969"
}
return listenAddr
}
func pprof(listenAddr string) {
log.Infof("listening pprof on: %s\n", listenAddr)
if err := http.ListenAndServe(listenAddr, nil); err != nil {
log.Fatalf("Failed to start pprof: %v", err)
}
}

View File

@@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"sync"
"syscall" "syscall"
"time" "time"
@@ -17,8 +18,9 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/relay/auth" "github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -40,6 +42,7 @@ type Config struct {
AuthSecret string AuthSecret string
LogLevel string LogLevel string
LogFile string LogFile string
HealthcheckListenAddress string
} }
func (c Config) Validate() error { func (c Config) Validate() error {
@@ -87,6 +90,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret") rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level") rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file") rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
setFlagsFromEnvVars(rootCmd) setFlagsFromEnvVars(rootCmd)
} }
@@ -102,6 +106,7 @@ func waitForExitSignal() {
} }
func execute(cmd *cobra.Command, args []string) error { func execute(cmd *cobra.Command, args []string) error {
wg := sync.WaitGroup{}
err := cobraConfig.Validate() err := cobraConfig.Validate()
if err != nil { if err != nil {
log.Debugf("invalid config: %s", err) log.Debugf("invalid config: %s", err)
@@ -120,7 +125,9 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("setup metrics: %v", err) return fmt.Errorf("setup metrics: %v", err)
} }
wg.Add(1)
go func() { go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint) log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start metrics server: %v", err) log.Fatalf("Failed to start metrics server: %v", err)
@@ -154,12 +161,31 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to create relay server: %v", err) return fmt.Errorf("failed to create relay server: %v", err)
} }
log.Infof("server will be available on: %s", srv.InstanceURL()) log.Infof("server will be available on: %s", srv.InstanceURL())
wg.Add(1)
go func() { go func() {
defer wg.Done()
if err := srv.Listen(srvListenerCfg); err != nil { if err := srv.Listen(srvListenerCfg); err != nil {
log.Fatalf("failed to bind server: %s", err) log.Fatalf("failed to bind server: %s", err)
} }
}() }()
hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return fmt.Errorf("failed to create healthcheck server: %v", err)
}
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start healthcheck server: %v", err)
}
}()
// it will block until exit signal // it will block until exit signal
waitForExitSignal() waitForExitSignal()
@@ -167,6 +193,10 @@ func execute(cmd *cobra.Command, args []string) error {
defer cancel() defer cancel()
var shutDownErrors error var shutDownErrors error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
}
if err := srv.Shutdown(ctx); err != nil { if err := srv.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err)) shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
} }
@@ -175,6 +205,8 @@ func execute(cmd *cobra.Command, args []string) error {
if err := metricsServer.Shutdown(ctx); err != nil { if err := metricsServer.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err)) shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
} }
wg.Wait()
return shutDownErrors return shutDownErrors
} }

View File

@@ -0,0 +1,195 @@
package healthcheck
import (
"context"
"encoding/json"
"errors"
"net"
"net/http"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws"
)
const (
statusHealthy = "healthy"
statusUnhealthy = "unhealthy"
path = "/health"
cacheTTL = 3 * time.Second // Cache TTL for health status
)
type ServiceChecker interface {
ListenerProtocols() []protocol.Protocol
ListenAddress() string
}
type HealthStatus struct {
Status string `json:"status"`
Timestamp time.Time `json:"timestamp"`
Listeners []protocol.Protocol `json:"listeners"`
CertificateValid bool `json:"certificate_valid"`
}
type Config struct {
ListenAddress string
ServiceChecker ServiceChecker
}
type Server struct {
config Config
httpServer *http.Server
cacheMu sync.Mutex
cacheStatus *HealthStatus
}
func NewServer(config Config) (*Server, error) {
mux := http.NewServeMux()
if config.ServiceChecker == nil {
return nil, errors.New("service checker is required")
}
server := &Server{
config: config,
httpServer: &http.Server{
Addr: config.ListenAddress,
Handler: mux,
ReadTimeout: 5 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 15 * time.Second,
},
}
mux.HandleFunc(path, server.handleHealthcheck)
return server, nil
}
func (s *Server) ListenAndServe() error {
log.Infof("starting healthcheck server on: http://%s%s", dialAddress(s.config.ListenAddress), path)
return s.httpServer.ListenAndServe()
}
// Shutdown gracefully shuts down the healthcheck server
func (s *Server) Shutdown(ctx context.Context) error {
log.Info("Shutting down healthcheck server")
return s.httpServer.Shutdown(ctx)
}
func (s *Server) handleHealthcheck(w http.ResponseWriter, _ *http.Request) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
var (
status *HealthStatus
ok bool
)
// Cache check
s.cacheMu.Lock()
status = s.cacheStatus
s.cacheMu.Unlock()
if status != nil && time.Since(status.Timestamp) <= cacheTTL {
ok = status.Status == statusHealthy
} else {
status, ok = s.getHealthStatus(ctx)
// Update cache
s.cacheMu.Lock()
s.cacheStatus = status
s.cacheMu.Unlock()
}
w.Header().Set("Content-Type", "application/json")
if ok {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusServiceUnavailable)
}
encoder := json.NewEncoder(w)
if err := encoder.Encode(status); err != nil {
log.Errorf("Failed to encode healthcheck response: %v", err)
}
}
func (s *Server) getHealthStatus(ctx context.Context) (*HealthStatus, bool) {
healthy := true
status := &HealthStatus{
Timestamp: time.Now(),
Status: statusHealthy,
CertificateValid: true,
}
listeners, ok := s.validateListeners()
if !ok {
status.Status = statusUnhealthy
healthy = false
}
status.Listeners = listeners
if ok := s.validateCertificate(ctx); !ok {
status.Status = statusUnhealthy
status.CertificateValid = false
healthy = false
}
return status, healthy
}
func (s *Server) validateListeners() ([]protocol.Protocol, bool) {
listeners := s.config.ServiceChecker.ListenerProtocols()
if len(listeners) == 0 {
return nil, false
}
return listeners, true
}
func (s *Server) validateCertificate(ctx context.Context) bool {
listenAddress := s.config.ServiceChecker.ListenAddress()
if listenAddress == "" {
log.Warn("listen address is empty")
return false
}
dAddr := dialAddress(listenAddress)
for _, proto := range s.config.ServiceChecker.ListenerProtocols() {
switch proto {
case ws.Proto:
if err := dialWS(ctx, dAddr); err != nil {
log.Errorf("failed to dial WebSocket listener: %v", err)
return false
}
case quic.Proto:
if err := dialQUIC(ctx, dAddr); err != nil {
log.Errorf("failed to dial QUIC listener: %v", err)
return false
}
default:
log.Warnf("unknown protocol for healthcheck: %s", proto)
return false
}
}
return true
}
func dialAddress(listenAddress string) string {
host, port, err := net.SplitHostPort(listenAddress)
if err != nil {
return listenAddress // fallback, might be invalid for dialing
}
if host == "" || host == "::" || host == "0.0.0.0" {
host = "0.0.0.0"
}
return net.JoinHostPort(host, port)
}

31
relay/healthcheck/quic.go Normal file
View File

@@ -0,0 +1,31 @@
package healthcheck
import (
"context"
"crypto/tls"
"fmt"
"time"
"github.com/quic-go/quic-go"
tlsnb "github.com/netbirdio/netbird/shared/relay/tls"
)
func dialQUIC(ctx context.Context, address string) error {
tlsConfig := &tls.Config{
InsecureSkipVerify: false, // Keep certificate validation enabled
NextProtos: []string{tlsnb.NBalpn},
}
conn, err := quic.DialAddr(ctx, address, tlsConfig, &quic.Config{
MaxIdleTimeout: 30 * time.Second,
KeepAlivePeriod: 10 * time.Second,
EnableDatagrams: true,
})
if err != nil {
return fmt.Errorf("failed to connect to QUIC server: %w", err)
}
_ = conn.CloseWithError(0, "availability check complete")
return nil
}

28
relay/healthcheck/ws.go Normal file
View File

@@ -0,0 +1,28 @@
package healthcheck
import (
"context"
"fmt"
"github.com/coder/websocket"
"github.com/netbirdio/netbird/shared/relay"
)
func dialWS(ctx context.Context, address string) error {
url := fmt.Sprintf("wss://%s%s", address, relay.WebSocketURLPath)
conn, resp, err := websocket.Dial(ctx, url, nil)
if resp != nil {
defer func() {
_ = resp.Body.Close()
}()
}
if err != nil {
return fmt.Errorf("failed to connect to websocket: %w", err)
}
_ = conn.Close(websocket.StatusNormalClosure, "availability check complete")
return nil
}

View File

@@ -0,0 +1,3 @@
package protocol
type Protocol string

View File

@@ -3,9 +3,12 @@ package listener
import ( import (
"context" "context"
"net" "net"
"github.com/netbirdio/netbird/relay/protocol"
) )
type Listener interface { type Listener interface {
Listen(func(conn net.Conn)) error Listen(func(conn net.Conn)) error
Shutdown(ctx context.Context) error Shutdown(ctx context.Context) error
Protocol() protocol.Protocol
} }

View File

@@ -9,8 +9,12 @@ import (
"github.com/quic-go/quic-go" "github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
) )
const Proto protocol.Protocol = "quic"
type Listener struct { type Listener struct {
// Address is the address to listen on // Address is the address to listen on
Address string Address string
@@ -50,6 +54,10 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
} }
} }
func (l *Listener) Protocol() protocol.Protocol {
return Proto
}
func (l *Listener) Shutdown(ctx context.Context) error { func (l *Listener) Shutdown(ctx context.Context) error {
if l.listener == nil { if l.listener == nil {
return nil return nil

View File

@@ -11,11 +11,14 @@ import (
"github.com/coder/websocket" "github.com/coder/websocket"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/shared/relay" "github.com/netbirdio/netbird/shared/relay"
) )
// URLPath is the path for the websocket connection. const (
const URLPath = relay.WebSocketURLPath Proto protocol.Protocol = "ws"
URLPath = relay.WebSocketURLPath
)
type Listener struct { type Listener struct {
// Address is the address to listen on. // Address is the address to listen on.
@@ -51,6 +54,10 @@ func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
return err return err
} }
func (l *Listener) Protocol() protocol.Protocol {
return Proto
}
func (l *Listener) Shutdown(ctx context.Context) error { func (l *Listener) Shutdown(ctx context.Context) error {
if l.server == nil { if l.server == nil {
return nil return nil

View File

@@ -6,12 +6,14 @@ import (
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/relay/protocol"
"github.com/netbirdio/netbird/relay/server/listener" "github.com/netbirdio/netbird/relay/server/listener"
"github.com/netbirdio/netbird/relay/server/listener/quic" "github.com/netbirdio/netbird/relay/server/listener/quic"
"github.com/netbirdio/netbird/relay/server/listener/ws" "github.com/netbirdio/netbird/relay/server/listener/ws"
quictls "github.com/netbirdio/netbird/shared/relay/tls" quictls "github.com/netbirdio/netbird/shared/relay/tls"
log "github.com/sirupsen/logrus"
) )
// ListenerConfig is the configuration for the listener. // ListenerConfig is the configuration for the listener.
@@ -26,8 +28,11 @@ type ListenerConfig struct {
// It is the gate between the WebSocket listener and the Relay server logic. // It is the gate between the WebSocket listener and the Relay server logic.
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. // In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct { type Server struct {
listenAddr string
relay *Relay relay *Relay
listeners []listener.Listener listeners []listener.Listener
listenerMux sync.Mutex
} }
// NewServer creates and returns a new relay server instance. // NewServer creates and returns a new relay server instance.
@@ -57,10 +62,14 @@ func NewServer(config Config) (*Server, error) {
// Listen starts the relay server. // Listen starts the relay server.
func (r *Server) Listen(cfg ListenerConfig) error { func (r *Server) Listen(cfg ListenerConfig) error {
r.listenAddr = cfg.Address
wSListener := &ws.Listener{ wSListener := &ws.Listener{
Address: cfg.Address, Address: cfg.Address,
TLSConfig: cfg.TLSConfig, TLSConfig: cfg.TLSConfig,
} }
r.listenerMux.Lock()
r.listeners = append(r.listeners, wSListener) r.listeners = append(r.listeners, wSListener)
tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig) tlsConfigQUIC, err := quictls.ServerQUICTLSConfig(cfg.TLSConfig)
@@ -85,6 +94,8 @@ func (r *Server) Listen(cfg ListenerConfig) error {
}(l) }(l)
} }
r.listenerMux.Unlock()
wg.Wait() wg.Wait()
close(errChan) close(errChan)
var multiErr *multierror.Error var multiErr *multierror.Error
@@ -100,12 +111,15 @@ func (r *Server) Listen(cfg ListenerConfig) error {
func (r *Server) Shutdown(ctx context.Context) error { func (r *Server) Shutdown(ctx context.Context) error {
r.relay.Shutdown(ctx) r.relay.Shutdown(ctx)
r.listenerMux.Lock()
var multiErr *multierror.Error var multiErr *multierror.Error
for _, l := range r.listeners { for _, l := range r.listeners {
if err := l.Shutdown(ctx); err != nil { if err := l.Shutdown(ctx); err != nil {
multiErr = multierror.Append(multiErr, err) multiErr = multierror.Append(multiErr, err)
} }
} }
r.listeners = r.listeners[:0]
r.listenerMux.Unlock()
return nberrors.FormatErrorOrNil(multiErr) return nberrors.FormatErrorOrNil(multiErr)
} }
@@ -113,3 +127,18 @@ func (r *Server) Shutdown(ctx context.Context) error {
func (r *Server) InstanceURL() string { func (r *Server) InstanceURL() string {
return r.relay.instanceURL return r.relay.instanceURL
} }
func (r *Server) ListenerProtocols() []protocol.Protocol {
result := make([]protocol.Protocol, 0)
r.listenerMux.Lock()
for _, l := range r.listeners {
result = append(result, l.Protocol())
}
r.listenerMux.Unlock()
return result
}
func (r *Server) ListenAddress() string {
return r.listenAddr
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
@@ -111,7 +112,9 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) groupsManager := groups.NewManagerMock()
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -162,6 +162,12 @@ components:
description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored. description: Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
type: boolean type: boolean
example: true example: true
network_traffic_logs_groups:
description: Limits traffic logging to these groups. If unset all peers are enabled.
type: array
items:
type: string
example: ch8i4ug6lnn4g9hqv7m0
network_traffic_packet_counter_enabled: network_traffic_packet_counter_enabled:
description: Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance) description: Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance)
type: boolean type: boolean
@@ -169,6 +175,7 @@ components:
required: required:
- peer_approval_enabled - peer_approval_enabled
- network_traffic_logs_enabled - network_traffic_logs_enabled
- network_traffic_logs_groups
- network_traffic_packet_counter_enabled - network_traffic_packet_counter_enabled
AccountRequest: AccountRequest:
type: object type: object

View File

@@ -260,6 +260,9 @@ type AccountExtraSettings struct {
// NetworkTrafficLogsEnabled Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored. // NetworkTrafficLogsEnabled Enables or disables network traffic logging. If enabled, all network traffic events from peers will be stored.
NetworkTrafficLogsEnabled bool `json:"network_traffic_logs_enabled"` NetworkTrafficLogsEnabled bool `json:"network_traffic_logs_enabled"`
// NetworkTrafficLogsGroups Limits traffic logging to these groups. If unset all peers are enabled.
NetworkTrafficLogsGroups []string `json:"network_traffic_logs_groups"`
// NetworkTrafficPacketCounterEnabled Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance) // NetworkTrafficPacketCounterEnabled Enables or disables network traffic packet counter. If enabled, network packets and their size will be counted and reported. (This can have an slight impact on performance)
NetworkTrafficPacketCounterEnabled bool `json:"network_traffic_packet_counter_enabled"` NetworkTrafficPacketCounterEnabled bool `json:"network_traffic_packet_counter_enabled"`

View File

@@ -1,3 +1,3 @@
package tls package tls
const nbalpn = "nb-quic" const NBalpn = "nb-quic"

View File

@@ -20,7 +20,7 @@ func ClientQUICTLSConfig() *tls.Config {
return &tls.Config{ return &tls.Config{
InsecureSkipVerify: true, // Debug mode allows insecure connections InsecureSkipVerify: true, // Debug mode allows insecure connections
NextProtos: []string{nbalpn}, // Ensure this matches the server's ALPN NextProtos: []string{NBalpn}, // Ensure this matches the server's ALPN
RootCAs: certPool, RootCAs: certPool,
} }
} }

View File

@@ -19,7 +19,7 @@ func ClientQUICTLSConfig() *tls.Config {
} }
return &tls.Config{ return &tls.Config{
NextProtos: []string{nbalpn}, NextProtos: []string{NBalpn},
RootCAs: certPool, RootCAs: certPool,
} }
} }

View File

@@ -23,7 +23,7 @@ func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
} }
cfg := originTLSCfg.Clone() cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn} cfg.NextProtos = []string{NBalpn}
return cfg, nil return cfg, nil
} }
@@ -74,6 +74,6 @@ func generateTestTLSConfig() (*tls.Config, error) {
return &tls.Config{ return &tls.Config{
Certificates: []tls.Certificate{tlsCert}, Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{nbalpn}, NextProtos: []string{NBalpn},
}, nil }, nil
} }

View File

@@ -12,6 +12,6 @@ func ServerQUICTLSConfig(originTLSCfg *tls.Config) (*tls.Config, error) {
return nil, fmt.Errorf("valid TLS config is required for QUIC listener") return nil, fmt.Errorf("valid TLS config is required for QUIC listener")
} }
cfg := originTLSCfg.Clone() cfg := originTLSCfg.Clone()
cfg.NextProtos = []string{nbalpn} cfg.NextProtos = []string{NBalpn}
return cfg, nil return cfg, nil
} }

View File

@@ -52,7 +52,7 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
} }
// MarshalCredential marshal a Credential instance and returns a Message object // MarshalCredential marshal a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) { func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string, sessionID []byte) (*proto.Message, error) {
return &proto.Message{ return &proto.Message{
Key: myKey.PublicKey().String(), Key: myKey.PublicKey().String(),
RemoteKey: remoteKey, RemoteKey: remoteKey,
@@ -66,6 +66,7 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credenti
RosenpassServerAddr: rosenpassAddr, RosenpassServerAddr: rosenpassAddr,
}, },
RelayServerAddress: relaySrvAddress, RelayServerAddress: relaySrvAddress,
SessionId: sessionID,
}, },
}, nil }, nil
} }

View File

@@ -45,19 +45,10 @@ type GrpcClient struct {
connStateCallbackLock sync.RWMutex connStateCallbackLock sync.RWMutex
onReconnectedListenerFn func() onReconnectedListenerFn func()
}
func (c *GrpcClient) StreamConnected() bool { decryptionWorker *Worker
return c.status == StreamConnected decryptionWorkerCancel context.CancelFunc
} decryptionWg sync.WaitGroup
func (c *GrpcClient) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *GrpcClient) Close() error {
return c.signalConn.Close()
} }
// NewClient creates a new Signal client // NewClient creates a new Signal client
@@ -93,6 +84,25 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
}, nil }, nil
} }
func (c *GrpcClient) StreamConnected() bool {
return c.status == StreamConnected
}
func (c *GrpcClient) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *GrpcClient) Close() error {
if c.decryptionWorkerCancel != nil {
c.decryptionWorkerCancel()
}
c.decryptionWg.Wait()
c.decryptionWorker = nil
return c.signalConn.Close()
}
// SetConnStateListener set the ConnStateNotifier // SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) { func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock() c.connStateCallbackLock.Lock()
@@ -148,8 +158,12 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
log.Infof("connected to the Signal Service stream") log.Infof("connected to the Signal Service stream")
c.notifyConnected() c.notifyConnected()
// Start worker pool if not already started
c.startEncryptionWorker(msgHandler)
// start receiving messages from the Signal stream (from other peers through signal) // start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler) err = c.receive(stream)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
log.Debugf("signal connection context has been canceled, this usually indicates shutdown") log.Debugf("signal connection context has been canceled, this usually indicates shutdown")
@@ -174,6 +188,7 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
return nil return nil
} }
func (c *GrpcClient) notifyStreamDisconnected() { func (c *GrpcClient) notifyStreamDisconnected() {
c.mux.Lock() c.mux.Lock()
defer c.mux.Unlock() defer c.mux.Unlock()
@@ -382,11 +397,11 @@ func (c *GrpcClient) Send(msg *proto.Message) error {
} }
// receive receives messages from other peers coming through the Signal Exchange // receive receives messages from other peers coming through the Signal Exchange
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, // and distributes them to worker threads for processing
msgHandler func(msg *proto.Message) error) error { func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) error {
for { for {
msg, err := stream.Recv() msg, err := stream.Recv()
// Handle errors immediately
switch s, ok := status.FromError(err); { switch s, ok := status.FromError(err); {
case ok && s.Code() == codes.Canceled: case ok && s.Code() == codes.Canceled:
log.Debugf("stream canceled (usually indicates shutdown)") log.Debugf("stream canceled (usually indicates shutdown)")
@@ -398,22 +413,35 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
log.Debugf("Signal Service stream closed by server") log.Debugf("Signal Service stream closed by server")
return err return err
case err != nil: case err != nil:
log.Errorf("Stream receive error: %v", err)
return err return err
} }
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg) if msg == nil {
if err != nil { continue
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
} }
err = msgHandler(decryptedMessage) if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
log.Errorf("failed to add message to decryption worker: %v", err)
}
}
}
if err != nil { func (c *GrpcClient) startEncryptionWorker(handler func(msg *proto.Message) error) {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) if c.decryptionWorker != nil {
// todo send something?? return
}
} }
c.decryptionWorker = NewWorker(c.decryptMessage, handler)
workerCtx, workerCancel := context.WithCancel(context.Background())
c.decryptionWorkerCancel = workerCancel
c.decryptionWg.Add(1)
go func() {
defer workerCancel()
c.decryptionWorker.Work(workerCtx)
c.decryptionWg.Done()
}()
} }
func (c *GrpcClient) notifyDisconnected(err error) { func (c *GrpcClient) notifyDisconnected(err error) {

View File

@@ -0,0 +1,55 @@
package client
import (
"context"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/signal/proto"
)
type Worker struct {
decryptMessage func(msg *proto.EncryptedMessage) (*proto.Message, error)
handler func(msg *proto.Message) error
encryptedMsgPool chan *proto.EncryptedMessage
}
func NewWorker(decryptFn func(msg *proto.EncryptedMessage) (*proto.Message, error), handlerFn func(msg *proto.Message) error) *Worker {
return &Worker{
decryptMessage: decryptFn,
handler: handlerFn,
encryptedMsgPool: make(chan *proto.EncryptedMessage, 1),
}
}
func (w *Worker) AddMsg(ctx context.Context, msg *proto.EncryptedMessage) error {
// this is blocker because do not want to drop messages here
select {
case w.encryptedMsgPool <- msg:
case <-ctx.Done():
}
return nil
}
func (w *Worker) Work(ctx context.Context) {
for {
select {
case msg := <-w.encryptedMsgPool:
decryptedMessage, err := w.decryptMessage(msg)
if err != nil {
log.Errorf("failed to decrypt message: %v", err)
continue
}
if err := w.handler(decryptedMessage); err != nil {
log.Errorf("failed to handle message: %v", err)
continue
}
case <-ctx.Done():
log.Infof("Message worker stopping due to context cancellation")
return
}
}
}

View File

@@ -230,6 +230,7 @@ type Body struct {
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"` RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
// relayServerAddress is url of the relay server // relayServerAddress is url of the relay server
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"` RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
SessionId []byte `protobuf:"bytes,10,opt,name=sessionId,proto3,oneof" json:"sessionId,omitempty"`
} }
func (x *Body) Reset() { func (x *Body) Reset() {
@@ -320,6 +321,13 @@ func (x *Body) GetRelayServerAddress() string {
return "" return ""
} }
func (x *Body) GetSessionId() []byte {
if x != nil {
return x.SessionId
}
return nil
}
// Mode indicates a connection mode // Mode indicates a connection mode
type Mode struct { type Mode struct {
state protoimpl.MessageState state protoimpl.MessageState
@@ -443,7 +451,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62, 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xb3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xe4, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, 0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -466,34 +474,37 @@ var file_signalexchange_proto_rawDesc = []byte{
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x21, 0x0a, 0x09, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x6e, 0x49, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0c, 0x48, 0x00, 0x52, 0x09, 0x73, 0x65, 0x73,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x88, 0x01, 0x01, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x12, 0x0b, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06,
0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x22, 0x2e, 0x0a, 0x04, 0x4d, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44,
0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10,
0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01, 0x04, 0x12, 0x0b, 0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x42, 0x0c,
0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52, 0x0a, 0x0a, 0x5f, 0x73, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x64, 0x22, 0x2e, 0x0a, 0x04,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28, 0x4d, 0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01,
0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x20, 0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x01, 0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f,
0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12,
0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18, 0x28, 0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b,
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72,
0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c,
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x0a, 0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70,
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d,
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45,
0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67,
0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
} }
var ( var (
@@ -601,6 +612,7 @@ func file_signalexchange_proto_init() {
} }
} }
} }
file_signalexchange_proto_msgTypes[2].OneofWrappers = []interface{}{}
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{} file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{

View File

@@ -64,6 +64,8 @@ message Body {
// relayServerAddress is url of the relay server // relayServerAddress is url of the relay server
string relayServerAddress = 8; string relayServerAddress = 8;
optional bytes sessionId = 10;
} }
// Mode indicates a connection mode // Mode indicates a connection mode