Compare commits

...

43 Commits

Author SHA1 Message Date
crn4
c8a9af2482 Merge branch 'main' into vk/debug/nmap-both 2025-11-11 14:32:04 +01:00
crn4
4bcda3e2ba use old map, new in goroutine 2025-11-11 14:31:47 +01:00
Zoltan Papp
c28275611b Fix agent reference (#4776) 2025-11-11 13:59:32 +01:00
Vlad
56f169eede [management] fix pg db deadlock after app panic (#4772) 2025-11-10 23:43:08 +01:00
crn4
b780f1c09d add both network maps compilation for debug 2025-11-10 21:19:07 +01:00
Viktor Liu
07cf9d5895 [client] Create networkd.conf.d if it doesn't exist (#4764) 2025-11-08 10:54:37 +01:00
Pascal Fischer
7df49e249d [management ] remove timing logs (#4761) 2025-11-07 20:14:52 +01:00
Pascal Fischer
dbfc8a52c9 [management] remove GLOBAL when disabling foreign keys on mysql (#4615) 2025-11-07 16:03:14 +01:00
Vlad
98ddac07bf [management] remove toAll firewall rule (#4725) 2025-11-07 15:50:58 +01:00
Pascal Fischer
48475ddc05 [management] add pat rate limiting (#4741) 2025-11-07 15:50:18 +01:00
Vlad
6aa4ba7af4 [management] incremental network map builder (#4753) 2025-11-07 10:44:46 +01:00
dependabot[bot]
2e16c9914a [management] Bump github.com/containerd/containerd from 1.7.27 to 1.7.29 (#4756)
Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.27 to 1.7.29.
- [Release notes](https://github.com/containerd/containerd/releases)
- [Changelog](https://github.com/containerd/containerd/blob/main/RELEASES.md)
- [Commits](https://github.com/containerd/containerd/compare/v1.7.27...v1.7.29)

---
updated-dependencies:
- dependency-name: github.com/containerd/containerd
  dependency-version: 1.7.29
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-06 19:01:44 +03:00
Pascal Fischer
5c29d395b2 [management] activity events on group updates (#4750) 2025-11-06 12:51:14 +01:00
Viktor Liu
229e0038ee [client] Add dns config to debug bundle (#4704) 2025-11-05 17:30:17 +01:00
Viktor Liu
75327d9519 [client] Add login_hint to oidc flows (#4724) 2025-11-05 17:00:20 +01:00
Viktor Liu
c92e6c1b5f [client] Block on all subsystems on shutdown (#4709) 2025-11-05 12:15:37 +01:00
Viktor Liu
641eb5140b [client] Allow INPUT traffic on the compat iptables filter table for nftables (#4742) 2025-11-04 21:56:53 +01:00
Viktor Liu
45c25dca84 [client] Clamp MSS on outbound traffic (#4735) 2025-11-04 17:18:51 +01:00
Viktor Liu
679c58ce47 [client] Set up networkd to ignore ip rules (#4730) 2025-11-04 17:06:35 +01:00
Pascal Fischer
719283c792 [management] update db connection lifecycle configuration (#4740) 2025-11-03 17:40:12 +01:00
dependabot[bot]
a2313a5ba4 [client] Bump github.com/quic-go/quic-go from 0.48.2 to 0.49.1 (#4621)
Bumps [github.com/quic-go/quic-go](https://github.com/quic-go/quic-go) from 0.48.2 to 0.49.1.
- [Release notes](https://github.com/quic-go/quic-go/releases)
- [Commits](https://github.com/quic-go/quic-go/compare/v0.48.2...v0.49.1)

---
updated-dependencies:
- dependency-name: github.com/quic-go/quic-go
  dependency-version: 0.49.1
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-01 15:27:22 +01:00
Zoltan Papp
8c108ccad3 [client] Extend Darwin network monitoring with wakeup detection 2025-10-31 19:19:14 +01:00
Viktor Liu
86eff0d750 [client] Fix netstack dns forwarder (#4727) 2025-10-31 14:18:09 +01:00
Viktor Liu
43c9a51913 [client] Migrate deprecated grpc client code (#4687) 2025-10-30 10:14:27 +01:00
Viktor Liu
c530db1455 [client] Fix UI panic when switching profiles (#4718) 2025-10-29 17:27:18 +01:00
Viktor Liu
1ee575befe [client] Use management-provided dns forwarder port on the client side (#4712) 2025-10-28 22:58:43 +01:00
Viktor Liu
d3a34adcc9 [client] Fix Connect/Disconnect buttons being enabled or disabled at the same time (#4711) 2025-10-28 21:21:40 +01:00
Zoltan Papp
d7321c130b [client] The status cmd will not be blocked by the ICE probe (#4597)
The status cmd will not be blocked by the ICE probe

Refactor the TURN and STUN probe, and cache the results. The NetBird status command will indicate a "checking…" state.
2025-10-28 16:11:35 +01:00
Viktor Liu
404cab90ba [client] Redirect dns forwarder port 5353 to new listening port 22054 (#4707)
- Port dnat changes from https://github.com/netbirdio/netbird/pull/4015 (nftables/iptables/userspace)
  - For userspace: rewrite the original port to the target port
  - Remember original destination port in conntrack
  - Rewrite the source port back to the original port for replies
- Redirect incoming port 5353 to 22054 (tcp/udp)
- Revert port changes based on the network map received from management
- Adjust tracer to show NAT stages
2025-10-28 15:12:53 +01:00
Pascal Fischer
4545ab9a52 [management] rewire account manager to permissions manager (#4673) 2025-10-27 22:59:35 +01:00
Bethuel Mmbaga
7f08983207 Include expired and routing peers in DNS record filtering (#4708) 2025-10-27 22:16:17 +03:00
Viktor Liu
eddea14521 [client] Clean up bsd routes independently of the state file (#4688) 2025-10-27 18:54:00 +01:00
Viktor Liu
b9ef214ea5 [client] Fix macOS state-based dns cleanup (#4701) 2025-10-27 18:35:32 +01:00
Bethuel Mmbaga
709e24eb6f [signal] Fix HTTP/WebSocket proxy not using custom certificates (#4644)
This pull request fixes a bug where the HTTP/WebSocket proxy server was not using custom TLS certificates when provided via --cert-file and --cert-key flags. Previously, only the gRPC server had TLS enabled with custom certificates, while the HTTP/WebSocket proxy ran without TLS.
2025-10-24 15:40:20 +03:00
Viktor Liu
6654e2dbf7 [client] Fix active profile name in debug bundle (#4689) 2025-10-23 17:07:52 +02:00
Bethuel Mmbaga
d80d47a469 [management] Add peer disapproval reason (#4468) 2025-10-22 12:46:22 +03:00
Maycon Santos
96f71ff1e1 [misc] Update tag name extraction in install.sh (#4677) 2025-10-21 19:23:11 +02:00
Viktor Liu
2fe2af38d2 [client] Clean up match domain reg entries between config changes (#4676) 2025-10-21 18:14:39 +02:00
Misha Bragin
cd9a867ad0 [client] Delete TURNConfig section from script (#4639) 2025-10-17 19:48:26 +02:00
Maycon Santos
0f9bfeff7c [client] Security upgrade alpine from 3.22.0 to 3.22.2 #4618 2025-10-17 19:47:11 +02:00
Viktor Liu
f5301230bf [client] Fix status showing P2P without connection (#4661) 2025-10-17 13:31:15 +02:00
Viktor Liu
429d7d6585 [client] Support BROWSER env for login (#4654) 2025-10-17 11:10:16 +02:00
Viktor Liu
3cdb10cde7 [client] Remove rule squashing (#4653) 2025-10-17 11:09:39 +02:00
164 changed files with 11792 additions and 1762 deletions

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0 FROM alpine:3.22.2
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \ RUN apk add --no-cache \
bash \ bash \

View File

@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
} }
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false) oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true})
if err != nil { if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
} }
@@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error {
func getStatusOutput(cmd *cobra.Command, anon bool) string { func getStatusOutput(cmd *cobra.Command, anon bool) string {
var statusOutputString string var statusOutputString string
statusResp, err := getStatus(cmd.Context()) statusResp, err := getStatus(cmd.Context(), true)
if err != nil { if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
) )
} }
return statusOutputString return statusOutputString

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/exec"
"os/user" "os/user"
"runtime" "runtime"
"strings" "strings"
@@ -105,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
Username: &username, Username: &username,
} }
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey loginRequest.OptionalPreSharedKey = &preSharedKey
} }
@@ -240,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err) return fmt.Errorf("read config file %s: %v", configFilePath, err)
} }
err = foregroundLogin(ctx, cmd, config, setupKey) err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -268,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil return nil
} }
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error { func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
needsLogin := false needsLogin := false
err := WithBackOff(func() error { err := WithBackOff(func() error {
@@ -285,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := "" jwtToken := ""
if setupKey == "" && needsLogin { if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config) tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
if err != nil { if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err) return fmt.Errorf("interactive sso login failed: %v", err)
} }
@@ -314,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil return nil
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -356,13 +373,21 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("") cmd.Println("")
if !noBrowser { if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil { if err := openBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }
} }
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -10,6 +10,8 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
log "github.com/sirupsen/logrus"
"github.com/kardianos/service" "github.com/kardianos/service"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error {
svcConfig.Option["LogDirectory"] = dir svcConfig.Option["LogDirectory"] = dir
} }
} }
if err := configureSystemdNetworkd(); err != nil {
log.Warnf("failed to configure systemd-networkd: %v", err)
}
} }
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
@@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{
return fmt.Errorf("uninstall service: %w", err) return fmt.Errorf("uninstall service: %w", err)
} }
if runtime.GOOS == "linux" {
if err := cleanupSystemdNetworkd(); err != nil {
log.Warnf("failed to cleanup systemd-networkd configuration: %v", err)
}
}
cmd.Println("NetBird service has been uninstalled") cmd.Println("NetBird service has been uninstalled")
return nil return nil
}, },
@@ -245,3 +257,50 @@ func isServiceRunning() (bool, error) {
return status == service.StatusRunning, nil return status == service.StatusRunning, nil
} }
const (
networkdConf = "/etc/systemd/networkd.conf"
networkdConfDir = "/etc/systemd/networkd.conf.d"
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
# routes and policy rules managed by NetBird.
[Network]
ManageForeignRoutes=no
ManageForeignRoutingPolicyRules=no
`
)
// configureSystemdNetworkd creates a drop-in configuration file to prevent
// systemd-networkd from removing NetBird's routes and policy rules.
func configureSystemdNetworkd() error {
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
log.Debug("systemd-networkd not in use, skipping configuration")
return nil
}
// nolint:gosec // standard networkd permissions
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
return fmt.Errorf("create networkd.conf.d directory: %w", err)
}
// nolint:gosec // standard networkd permissions
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
return fmt.Errorf("write networkd configuration: %w", err)
}
return nil
}
// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file.
func cleanupSystemdNetworkd() error {
if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) {
return nil
}
if err := os.Remove(networkdConfFile); err != nil {
return fmt.Errorf("remove networkd configuration: %w", err)
}
return nil
}

View File

@@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
resp, err := getStatus(ctx) resp, err := getStatus(ctx, false)
if err != nil { if err != nil {
return err return err
} }
@@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatus(ctx context.Context) (*proto.StatusResponse, error) { func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
} }
defer conn.Close() defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
if err != nil { if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
} }

View File

@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey) err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
loginRequest.ProfileName = &activeProf.Name loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse

View File

@@ -15,13 +15,13 @@ import (
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }
// use userspace packet filtering firewall // use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return fm, err return fm, err
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
} }
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
fm, err := createFW(iface) fm, err := createFW(iface, mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("create firewall: %s", err) return nil, fmt.Errorf("create firewall: %s", err)
} }
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
return fm, nil return fm, nil
} }
func createFW(iface IFaceMapper) (firewall.Manager, error) { func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface) return nbiptables.Create(iface, mtu)
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface) return nbnftables.Create(iface, mtu)
default: default:
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found") return nil, errors.New("no firewall manager found")
} }
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
var errUsp error var errUsp error
if fm != nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} else { } else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
} }
if errUsp != nil { if errUsp != nil {

View File

@@ -400,7 +400,6 @@ func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action fi
return "" return ""
} }
// Include action in the ipset name to prevent squashing rules with different actions
actionSuffix := "" actionSuffix := ""
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
actionSuffix = "-drop" actionSuffix = "-drop"

View File

@@ -36,7 +36,7 @@ type iFaceMapper interface {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("init iptables: %w", err) return nil, fmt.Errorf("init iptables: %w", err)
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient, ipv4Client: iptablesClient,
} }
m.router, err = newRouter(iptablesClient, wgIface) m.router, err = newRouter(iptablesClient, wgIface, mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("create router: %w", err) return nil, fmt.Errorf("create router: %w", err)
} }
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(), WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(), UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
}, },
} }
stateManager.RegisterState(state) stateManager.RegisterState(state)
@@ -260,6 +261,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes) return m.router.UpdateSet(set, prefixes)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// just check on the local interface // just check on the local interface
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -30,17 +30,20 @@ const (
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING" chainPREROUTING = "PREROUTING"
chainFORWARD = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWDIN = "NETBIRD-RT-FWD-IN" chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE" chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR" chainRTRDR = "NETBIRD-RT-RDR"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre" jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre" jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post" jumpNatPost = "jump-nat-post"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre" markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post" markManglePost = "mark-mangle-post"
matchSet = "--match-set" matchSet = "--match-set"
@@ -48,6 +51,9 @@ const (
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
snatSuffix = "_snat" snatSuffix = "_snat"
fwdSuffix = "_fwd" fwdSuffix = "_fwd"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
) )
type ruleInfo struct { type ruleInfo struct {
@@ -77,16 +83,18 @@ type router struct {
ipsetCounter *ipsetCounter ipsetCounter *ipsetCounter
wgIface iFaceMapper wgIface iFaceMapper
legacyManagement bool legacyManagement bool
mtu uint16
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
} }
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{ r := &router{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
@@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat}, {chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} { } {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
@@ -416,6 +425,7 @@ func (r *router) createContainers() error {
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat}, {chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} { } {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
@@ -438,6 +448,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("add jump rules: %w", err) return fmt.Errorf("add jump rules: %w", err)
} }
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil return nil
} }
@@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error {
return nil return nil
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSCLAMP,
}
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *router) insertEstablishedRule(chain string) error { func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished() establishedRule := getConntrackEstablished()
@@ -558,7 +601,7 @@ func (r *router) addJumpRules() error {
} }
func (r *router) cleanJumpRules() error { func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
var table, chain string var table, chain string
switch ruleKey { switch ruleKey {
@@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error {
case jumpNatPre: case jumpNatPre:
table = tableNat table = tableNat
chain = chainPREROUTING chain = chainPREROUTING
case jumpMSSClamp:
table = tableMangle
chain = chainFORWARD
default: default:
return fmt.Errorf("unknown jump rule: %s", ruleKey) return fmt.Errorf("unknown jump rule: %s", ruleKey)
} }
@@ -880,6 +926,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(string(protocol)),
"--dport", strconv.Itoa(int(sourcePort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(targetPort)),
}
ruleInfo := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = ruleInfo.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {
if port == nil { if port == nil {
return nil return nil

View File

@@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
assert.NoError(t, manager.Reset(), "shouldn't return error") assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
// Now 5 rules:
// 1. established rule forward in // 1. established rule forward in
// 2. estbalished rule forward out // 2. estbalished rule forward out
// 3. jump rule to POST nat chain // 3. jump rule to POST nat chain
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 7. static return masquerade rule // 7. static return masquerade rule
// 8. mangle prerouting mark rule // 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule // 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map") // 10. jump rule to MSS clamping chain
// 11. MSS clamping rule for outbound traffic
require.Len(t, manager.rules, 11, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
defer func() { defer func() {
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client") require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock) r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager") require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil)) require.NoError(t, r.init(nil))

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -11,6 +12,7 @@ type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState) mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
ipt, err := Create(s.InterfaceState, mtu)
if err != nil { if err != nil {
return fmt.Errorf("create iptables manager: %w", err) return fmt.Errorf("create iptables manager: %w", err)
} }

View File

@@ -100,6 +100,9 @@ type Manager interface {
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
//
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering( AddPeerFiltering(
id []byte, id []byte,
ip net.IP, ip net.IP,
@@ -151,14 +154,20 @@ type Manager interface {
DisableRouting() error DisableRouting() error
// AddDNATRule adds a DNAT rule // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network.
AddDNATRule(ForwardRule) (Rule, error) AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule // DeleteDNATRule deletes the outbound DNAT rule.
DeleteDNATRule(Rule) error DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes // UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error UpdateSet(hash Set, prefixes []netip.Prefix) error
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services
AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
// RemoveInboundDNAT removes inbound DNAT rule
RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -29,8 +29,6 @@ const (
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting" chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting" chainNameManglePostrouting = "netbird-mangle-postrouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
const flushError = "flush: %w" const flushError = "flush: %w"
@@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
// createDefaultAllowRules creates default allow rules for the input and output chains // createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error { func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{ expIn := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
}, },
@@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
action firewall.Action, action firewall.Action,
ipset *nftables.Set, ipset *nftables.Set,
) (*Rule, error) { ) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
nftRule: r.nftRule, nftRule: r.nftRule,
@@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
} }
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err) return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
} }
ruleStruct := &Rule{ ruleStruct := &Rule{
nftRule: nftRule, nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData), mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset, nftSet: ipset,
ruleID: ruleId, ruleID: ruleId,
@@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
}, },
) )
return m.rConn.AddRule(&nftables.Rule{ nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: m.chainPrerouting, Chain: m.chainPrerouting,
Exprs: preroutingExprs, Exprs: preroutingExprs,
UserData: userData, UserData: userData,
}) })
if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
} }
func (m *AclManager) createDefaultChains() (err error) { func (m *AclManager) createDefaultChains() (err error) {
@@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
return nil return nil
} }
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" rulesetID := ":" + string(proto) + ":"
if sPort != nil { if sPort != nil {
rulesetID += sPort.String() rulesetID += sPort.String()
} }

View File

@@ -1,11 +1,11 @@
package nftables package nftables
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"sync" "sync"
"github.com/google/nftables" "github.com/google/nftables"
@@ -19,13 +19,22 @@ import (
) )
const ( const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client // tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird" tableNameNetbird = "netbird"
// envTableName is the environment variable to override the table name
envTableName = "NB_NFTABLES_TABLE"
tableNameFilter = "filter" tableNameFilter = "filter"
chainNameInput = "INPUT" chainNameInput = "INPUT"
) )
func getTableName() string {
if name := os.Getenv(envTableName); name != "" {
return name
}
return tableNameNetbird
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
@@ -44,16 +53,16 @@ type Manager struct {
} }
// Create nftables firewall manager // Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
m := &Manager{ m := &Manager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
wgIface: wgIface, wgIface: wgIface,
} }
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
var err error var err error
m.router, err = newRouter(workTable, wgIface) m.router, err = newRouter(workTable, wgIface, mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("create router: %w", err) return nil, fmt.Errorf("create router: %w", err)
} }
@@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(), WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(), UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
}, },
}); err != nil { }); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)
@@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
err := m.aclManager.createDefaultAllowRules() if err := m.aclManager.createDefaultAllowRules(); err != nil {
if err != nil { return fmt.Errorf("create default allow rules: %w", err)
return fmt.Errorf("failed to create default allow rules: %v", err)
} }
if err := m.rConn.Flush(); err != nil {
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) return fmt.Errorf("flush allow input netbird rules: %w", err)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
} }
return nil return nil
@@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
if err := m.router.Reset(); err != nil { if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err) return fmt.Errorf("reset router: %v", err)
} }
@@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nil return nil
} }
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
m.deleteMatchingRules(rules)
}
}
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
func (m *Manager) cleanupNetbirdTables() error { func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list tables: %w", err) return fmt.Errorf("list tables: %w", err)
} }
tableName := getTableName()
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
@@ -376,61 +315,40 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return m.router.UpdateSet(set, prefixes) return m.router.UpdateSet(set, prefixes)
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("list of tables: %w", err) return nil, fmt.Errorf("list of tables: %w", err)
} }
tableName := getTableName()
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush() err = m.rConn.Flush()
return table, err return table, err
} }
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(allowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{ rule := &nftables.Rule{
Table: table, Table: table,

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
func TestNftablesManagerRuleOrder(t *testing.T) { func TestNftablesManagerRuleOrder(t *testing.T) {
// This test verifies rule insertion order in nftables peer ACLs // This test verifies rule insertion order in nftables peer ACLs
// We add accept rule first, then deny rule to test ordering behavior // We add accept rule first, then deny rule to test ordering behavior
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr := runIptablesSave(t) stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager") require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/nftables/xt" "github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -32,12 +33,17 @@ const (
chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect" chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD" chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
snatSuffix = "_snat" snatSuffix = "_snat"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
) )
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
@@ -63,9 +69,10 @@ type router struct {
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool legacyManagement bool
mtu uint16
} }
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{ r := &router{
conn: &nftables.Conn{}, conn: &nftables.Conn{},
workTable: workTable, workTable: workTable,
@@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
func (r *router) init(workTable *nftables.Table) error { func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil { if err := r.removeAcceptFilterRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) log.Errorf("failed to clean up rules from filter table: %s", err)
} }
if err := r.createContainers(); err != nil { if err := r.createContainers(); err != nil {
@@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error {
return nil return nil
} }
// Reset cleans existing nftables default forward rules from the system // Reset cleans existing nftables filter table rules from the system
func (r *router) Reset() error { func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller // clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear() r.ipsetCounter.Clear()
var merr *multierror.Error var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil { if err := r.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
} }
if err := r.removeNatPreroutingRules(); err != nil { if err := r.removeNatPreroutingRules(); err != nil {
@@ -220,11 +228,23 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
}) })
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
// Add the single NAT rule that matches on mark // Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil { if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err) return fmt.Errorf("add single nat rule: %v", err)
} }
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
if err := r.acceptForwardRules(); err != nil { if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err) log.Errorf("failed to add accept rules for the forward chain: %s", err)
} }
@@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error {
return nil return nil
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true) sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// that our traffic is not dropped by existing rules there. // that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface. // The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error { func (r *router) acceptForwardRules() error {
if r.filterTable == nil { if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules") log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
@@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error {
fw := "iptables" fw := "iptables"
defer func() { defer func() {
log.Debugf("Used %s to add accept forward rules", fw) log.Debugf("Used %s to add accept forward and input rules", fw)
}() }()
// Try iptables first and fallback to nftables if iptables is not available // Try iptables first and fallback to nftables if iptables is not available
@@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables" fw = "nftables"
return r.acceptForwardRulesNftables() return r.acceptFilterRulesNftables()
} }
return r.acceptForwardRulesIptables(ipt) return r.acceptFilterRulesIptables(ipt)
} }
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() { for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
} else { } else {
log.Debugf("added iptables rule: %v", rule) log.Debugf("added iptables forward rule: %v", rule)
} }
} }
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
@@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string {
} }
} }
func (r *router) acceptForwardRulesNftables() error { func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
func (r *router) acceptFilterRulesNftables() error {
intf := ifname(r.wgIface.Name()) intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{ iifRule := &nftables.Rule{
Table: r.filterTable, Table: r.filterTable,
Chain: &nftables.Chain{ Chain: &nftables.Chain{
@@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error {
}, },
} }
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{ oifRule := &nftables.Rule{
Table: r.filterTable, Table: r.filterTable,
Chain: &nftables.Chain{ Chain: &nftables.Chain{
Name: "FORWARD", Name: chainNameForward,
Table: r.filterTable, Table: r.filterTable,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward, Hooknum: nftables.ChainHookForward,
@@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error {
Exprs: append(oifExprs, getEstablishedExprs(2)...), Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif), UserData: []byte(userDataAcceptForwardRuleOif),
} }
r.conn.InsertRule(oifRule) r.conn.InsertRule(oifRule)
inputRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
}
r.conn.InsertRule(inputRule)
return nil return nil
} }
func (r *router) removeAcceptForwardRules() error { func (r *router) removeAcceptFilterRules() error {
if r.filterTable == nil { if r.filterTable == nil {
return nil return nil
} }
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New() ipt, err := iptables.New()
if err != nil { if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables() return r.removeAcceptFilterRulesNftables()
} }
return r.removeAcceptForwardRulesIptables(ipt) return r.removeAcceptFilterRulesIptables(ipt)
} }
func (r *router) removeAcceptForwardRulesNftables() error { func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return fmt.Errorf("list chains: %v", err) return fmt.Errorf("list chains: %v", err)
} }
for _, chain := range chains { for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != chainNameForward && chain.Name != chainNameInput {
continue continue
} }
@@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error {
for _, rule := range rules { for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil { if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err) return fmt.Errorf("delete rule: %v", err)
} }
@@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error {
return nil return nil
} }
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() { for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
} }
} }
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
@@ -1350,6 +1490,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil return nil
} }
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := protoToInt(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(sourcePort),
},
}
exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(targetPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
if rule, exists := r.rules[ruleID]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets // applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork( func (r *router) applyNetwork(
network firewall.Network, network firewall.Network,

View File

@@ -17,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
) )
const ( const (
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present // need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable() defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock) r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router") require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable)) require.NoError(t, r.init(workTable))
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable() defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock) r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router") require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable)) require.NoError(t, r.init(workTable))

View File

@@ -3,6 +3,7 @@ package nftables
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -10,6 +11,7 @@ type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState) mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
nft, err := Create(s.InterfaceState, mtu)
if err != nil { if err != nil {
return fmt.Errorf("create nftables manager: %w", err) return fmt.Errorf("create nftables manager: %w", err)
} }

View File

@@ -22,6 +22,8 @@ type BaseConnTrack struct {
PacketsRx atomic.Uint64 PacketsRx atomic.Uint64
BytesTx atomic.Uint64 BytesTx atomic.Uint64
BytesRx atomic.Uint64 BytesRx atomic.Uint64
DNATOrigPort atomic.Uint32
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler

View File

@@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
if exists { if exists {
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
return key, true return key, uint16(conn.DNATOrigPort.Load()), true
} }
return key, false return key, 0, false
} }
// TrackOutbound records an outbound TCP connection // TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists {
// if (inverted direction) conn is not tracked, track this direction return origPort
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
} }
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0)
return 0
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 { if exists || flags&TCPSyn == 0 {
return return
} }
@@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.tombstone.Store(false) conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew)) conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
t.logger.Trace2("New %s TCP connection: %s", direction, key) if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
t.updateState(key, conn, flags, direction, size) t.updateState(key, conn, flags, direction, size)
t.mutex.Lock() t.mutex.Lock()
@@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() {
} }
} }
// GetConnection safely retrieves a connection state
func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key]
return conn, exists
}
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.tickerCancel() t.tickerCancel()

View File

@@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
serverPort := uint16(80) serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound) // 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
key := ConnKey{ key := ConnKey{
SrcIP: clientIP, SrcIP: clientIP,
@@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake // 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer // 4. Test data transfer
// Client sends data // Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0)
// Server sends ACK for data // Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
@@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data // Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
// Verify state and counters // Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState()) require.Equal(t, TCPStateEstablished, conn.GetState())

View File

@@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size)
// if (inverted direction) conn is not tracked, track this direction if exists {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) return origPort
} }
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0)
return 0
} }
// TrackInbound records an inbound UDP connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort)
} }
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
if exists { if exists {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
return key, true return key, uint16(conn.DNATOrigPort.Load()), true
} }
return key, false return key, 0, false
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists { if exists {
return return
} }
@@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
} }
conn.DNATOrigPort.Store(uint32(origPort))
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size) conn.UpdateCounters(direction, size)
@@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace2("New %s UDP connection: %s", direction, key) if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
t.sendEvent(nftypes.TypeStart, conn, ruleID) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }

View File

@@ -1,6 +1,7 @@
package uspfilter package uspfilter
import ( import (
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -27,7 +28,12 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const (
layerTypeAll = 0
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const ( const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
@@ -36,6 +42,9 @@ const (
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
// EnvForceUserspaceRouter forces userspace routing even if native routing is available. // EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
@@ -50,6 +59,12 @@ const (
var errNatNotSupported = errors.New("nat not supported with userspace firewall") var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// serviceKey represents a protocol/port combination for netstack service registry
type serviceKey struct {
protocol gopacket.LayerType
port uint16
}
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule type RuleSet map[string]PeerRule
@@ -109,6 +124,17 @@ type Manager struct {
dnatMappings map[netip.Addr]netip.Addr dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex dnatMutex sync.RWMutex
dnatBiMap *biDNATMap dnatBiMap *biDNATMap
portDNATEnabled atomic.Bool
portDNATRules []portDNATRule
portDNATMutex sync.RWMutex
netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
mtu uint16
mssClampValue uint16
mssClampEnabled bool
} }
// decoder for packages // decoder for packages
@@ -122,19 +148,21 @@ type decoder struct {
icmp6 layers.ICMPv6 icmp6 layers.ICMPv6
decoded []gopacket.LayerType decoded []gopacket.LayerType
parser *gopacket.DecodingLayerParser parser *gopacket.DecodingLayerParser
dnatOrigPort uint16
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger) return create(iface, nil, disableServerRoutes, flowLogger, mtu)
} }
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
if nativeFirewall == nil { if nativeFirewall == nil {
return nil, errors.New("native firewall is nil") return nil, errors.New("native firewall is nil")
} }
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger) mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -142,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil return mgr, nil
} }
func parseCreateEnv() (bool, bool) { func parseCreateEnv() (bool, bool, bool) {
var disableConntrack, enableLocalForwarding bool var disableConntrack, enableLocalForwarding, disableMSSClamping bool
var err error var err error
if val := os.Getenv(EnvDisableConntrack); val != "" { if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val) disableConntrack, err = strconv.ParseBool(val)
@@ -162,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err) log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
} }
} }
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
disableMSSClamping, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
}
}
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding, disableMSSClamping
} }
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv() disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
m := &Manager{ m := &Manager{
decoders: sync.Pool{ decoders: sync.Pool{
@@ -196,13 +230,19 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr), dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
mtu: mtu,
} }
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
if !disableMSSClamping {
m.mssClampEnabled = true
m.mssClampValue = mtu - ipTCPHeaderMinSize
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
} }
if disableConntrack { if disableConntrack {
log.Info("conntrack is disabled") log.Info("conntrack is disabled")
} else { } else {
@@ -210,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
} }
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding { if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil { if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err) log.Errorf("failed to initialize forwarder: %v", err)
} }
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err) return nil, fmt.Errorf("set filter: %w", err)
} }
@@ -320,7 +357,7 @@ func (m *Manager) initForwarder() error {
return errors.New("forwarding not supported") return errors.New("forwarding not supported")
} }
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
if err != nil { if err != nil {
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err) return fmt.Errorf("create forwarder: %w", err)
@@ -626,11 +663,20 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return false return false
} }
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { switch d.decoded[1] {
return true case layers.LayerTypeUDP:
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
}
case layers.LayerTypeTCP:
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
m.clampTCPMSS(packetData, d)
}
} }
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, packetData, size)
m.translateOutboundDNAT(packetData, d) m.translateOutboundDNAT(packetData, d)
return false return false
@@ -674,14 +720,117 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { // clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
if !d.tcp.SYN {
return false
}
if len(d.tcp.Options) == 0 {
return false
}
mssOptionIndex := -1
var currentMSS uint16
for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > m.mssClampValue {
mssOptionIndex = i
break
}
}
}
if mssOptionIndex == -1 {
return false
}
ipHeaderSize := int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true
}
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20
optOffset := tcpOptionsStart
for j := 0; j < mssOptionIndex; j++ {
switch d.tcp.Options[j].OptionType {
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
optOffset++
default:
optOffset += 2 + len(d.tcp.Options[j].OptionData)
}
}
mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true
}
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart
tcpLayer[16] = 0
tcpLayer[17] = 0
var pseudoSum uint32
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum uint32 = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
if tcpLength%2 == 1 {
sum += uint32(tcpLayer[tcpLength-1]) << 8
}
for sum > 0xFFFF {
sum = (sum & 0xFFFF) + (sum >> 16)
}
checksum := ^uint16(sum)
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
if origPort == 0 {
break
}
if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite UDP port: %v", err)
}
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
if origPort == 0 {
break
}
if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil {
m.logger.Error1("failed to rewrite TCP port: %v", err)
}
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
} }
@@ -691,13 +840,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
} }
d.dnatOrigPort = 0
} }
// udpHooksDrop checks if any UDP hooks should drop the packet // udpHooksDrop checks if any UDP hooks should drop the packet
@@ -759,10 +910,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
return false return false
} }
// TODO: optimize port DNAT by caching matched rules in conntrack
if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated {
// Re-decode after port DNAT translation to update port information
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("failed to re-decode packet after port DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if translated := m.translateInboundReverse(packetData, d); translated { if translated := m.translateInboundReverse(packetData, d); translated {
// Re-decode after translation to get original addresses // Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err)
return true return true
} }
srcIP, dstIP = m.extractIPs(d) srcIP, dstIP = m.extractIPs(d)
@@ -807,9 +968,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true return true
} }
// If requested we pass local traffic to internal interfaces to the forwarder. if m.shouldForward(d, dstIP) {
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
return m.handleForwardedLocalTraffic(packetData) return m.handleForwardedLocalTraffic(packetData)
} }
@@ -1243,3 +1402,86 @@ func (m *Manager) DisableRouting() error {
return nil return nil
} }
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
m.netstackServiceMutex.Lock()
defer m.netstackServiceMutex.Unlock()
layerType := m.protocolToLayerType(protocol)
key := serviceKey{protocol: layerType, port: port}
m.netstackServices[key] = struct{}{}
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
}
// UnregisterNetstackService removes a service from the netstack registry
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
m.netstackServiceMutex.Lock()
defer m.netstackServiceMutex.Unlock()
layerType := m.protocolToLayerType(protocol)
key := serviceKey{protocol: layerType, port: port}
delete(m.netstackServices, key)
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
}
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
switch protocol {
case nftypes.TCP:
return layers.LayerTypeTCP
case nftypes.UDP:
return layers.LayerTypeUDP
case nftypes.ICMP:
return layers.LayerTypeICMPv4
default:
return gopacket.LayerType(0) // Invalid/unknown
}
}
// shouldForward determines if a packet should be forwarded to the forwarder.
// The forwarder handles routing packets to the native OS network stack.
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
// not enabled, never forward
if !m.localForwarding {
return false
}
// netstack always needs to forward because it's lacking a native interface
// exception for registered netstack services, those should go to netstack listeners
if m.netstack {
return !m.hasMatchingNetstackService(d)
}
// traffic to our other local interfaces (not NetBird IP) - always forward
if dstIP != m.wgIface.Address().IP {
return true
}
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
return false
}
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
if len(d.decoded) < 2 {
return false
}
var dstPort uint16
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
dstPort = uint16(d.udp.DstPort)
default:
return false
}
key := serviceKey{protocol: d.decoded[1], port: dstPort}
m.netstackServiceMutex.RLock()
_, exists := m.netstackServices[key]
m.netstackServiceMutex.RUnlock()
return exists
}

View File

@@ -17,6 +17,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup // Create manager and basic setup
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
} }
} }
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
// Set TCP flags
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes()
}
func BenchmarkRouteACLs(b *testing.B) { func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16") manager := setupRoutedManager(b, "10.10.0.100/16")
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
} }
} }
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
// This shows the overhead difference between the common case (non-SYN packets, fast path)
// and the rare case (SYN packets that need clamping, expensive path).
func BenchmarkMSSClamping(b *testing.B) {
scenarios := []struct {
name string
description string
genPacket func(*testing.B, net.IP, net.IP) []byte
frequency string
}{
{
name: "syn_needs_clamp",
description: "SYN packet needing MSS clamping",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
frequency: "~0.1% of traffic - EXPENSIVE",
},
{
name: "syn_no_clamp_needed",
description: "SYN packet with already-small MSS",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
},
frequency: "~0.05% of traffic",
},
{
name: "tcp_ack",
description: "Non-SYN TCP packet (ACK, data transfer)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
frequency: "~60-70% of traffic - FAST PATH",
},
{
name: "tcp_psh_ack",
description: "TCP data packet (PSH+ACK)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
},
frequency: "~10-20% of traffic - FAST PATH",
},
{
name: "udp",
description: "UDP packet",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
frequency: "~20-30% of traffic - FAST PATH",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
// for the common case (non-SYN TCP packets).
func BenchmarkMSSClampingOverhead(b *testing.B) {
scenarios := []struct {
name string
enabled bool
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "disabled_tcp_ack",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "enabled_tcp_ack",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "disabled_syn_needs_clamp",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "enabled_syn_needs_clamp",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = sc.enabled
if sc.enabled {
manager.mssClampValue = 1240
}
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
func BenchmarkMSSClampingMemory(b *testing.B) {
scenarios := []struct {
name string
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "tcp_ack_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "syn_needs_clamp",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "udp_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
b.Helper()
ip := &layers.IPv4{
Version: 4,
IHL: 5,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Seq: 1000,
Window: 65535,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
return buf.Bytes()
}

View File

@@ -12,6 +12,7 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
}, },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, manager) require.NotNil(t, manager)
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}, },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(tb, err) require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting()) require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager) require.NotNil(tb, manager)
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
}, },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))

View File

@@ -1,6 +1,7 @@
package uspfilter package uspfilter
import ( import (
"encoding/binary"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -17,6 +18,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) {
} }
// creating manager instance // creating manager instance
manager, err := Create(iface, false, flowLogger) manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
@@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
manager.udpTracker.Close() manager.udpTracker.Close()
@@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
@@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) {
require.Equal(t, tc.expected, isAllowed, tc.desc) require.Equal(t, tc.expected, isAllowed, tc.desc)
} }
} }
func TestMSSClamping(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, 1280)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
err = manager.UpdateLocalIPs()
require.NoError(t, err)
srcIP := net.ParseIP("100.10.0.2")
dstIP := net.ParseIP("8.8.8.8")
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
})
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
lowMSS := uint16(1200)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
})
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
})
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
})
}
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
ACK: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
Window: 65535,
}
if flags&uint16(conntrack.TCPSyn) != 0 {
tcpLayer.SYN = true
}
if flags&uint16(conntrack.TCPAck) != 0 {
tcpLayer.ACK = true
}
if flags&uint16(conntrack.TCPFin) != 0 {
tcpLayer.FIN = true
}
if flags&uint16(conntrack.TCPRst) != 0 {
tcpLayer.RST = true
}
if flags&uint16(conntrack.TCPPush) != 0 {
tcpLayer.PSH = true
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}

View File

@@ -45,7 +45,7 @@ type Forwarder struct {
netstack bool netstack bool
} }
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
HandleLocal: false, HandleLocal: false,
}) })
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1) nicID := tcpip.NICID(1)
endpoint := &endpoint{ endpoint := &endpoint{
logger: logger, logger: logger,
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
} }
if err := s.CreateNIC(nicID, endpoint); err != nil { if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("create NIC: %v", err)
} }
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{

View File

@@ -49,7 +49,7 @@ type idleConn struct {
conn *udpPacketConn conn *udpPacketConn
} }
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{ f := &udpForwarder{
logger: logger, logger: logger,

View File

@@ -50,6 +50,8 @@ type logMessage struct {
arg4 any arg4 any
arg5 any arg5 any
arg6 any arg6 any
arg7 any
arg8 any
} }
// Logger is a high-performance, non-blocking logger // Logger is a high-performance, non-blocking logger
@@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) {
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
} }
func (l *Logger) Error(format string) { func (l *Logger) Error(format string) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
select { select {
@@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
} }
} }
func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
if l.level.Load() >= uint32(LevelDebug) {
select {
case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
default:
}
}
}
func (l *Logger) Trace1(format string, arg1 any) { func (l *Logger) Trace1(format string, arg1 any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
select { select {
@@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
} }
} }
// Trace8 logs a trace message with 8 arguments (8 placeholder in format string)
func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
if l.level.Load() >= uint32(LevelTrace) {
select {
case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
default:
}
}
}
func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
*buf = (*buf)[:0] *buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
@@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
argCount++ argCount++
if msg.arg6 != nil { if msg.arg6 != nil {
argCount++ argCount++
if msg.arg7 != nil {
argCount++
if msg.arg8 != nil {
argCount++
}
}
} }
} }
} }
@@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5)
case 6: case 6:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6)
case 7:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7)
case 8:
formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8)
} }
*buf = append(*buf, formatted...) *buf = append(*buf, formatted...)

View File

@@ -5,7 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -13,6 +15,21 @@ import (
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
var (
errInvalidIPHeaderLength = errors.New("invalid IP header length")
)
const (
// Port offsets in TCP/UDP headers
sourcePortOffset = 0
destinationPortOffset = 2
// IP address offsets in IPv4 header
sourceIPOffset = 12
destinationIPOffset = 16
)
// ipv4Checksum calculates IPv4 header checksum.
func ipv4Checksum(header []byte) uint16 { func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 { if len(header) < 20 {
return 0 return 0
@@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// icmpChecksum calculates ICMP checksum.
func icmpChecksum(data []byte) uint16 { func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32 var sum1, sum2, sum3, sum4 uint32
i := 0 i := 0
@@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// biDNATMap maintains bidirectional DNAT mappings.
type biDNATMap struct { type biDNATMap struct {
forward map[netip.Addr]netip.Addr forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr
} }
// portDNATRule represents a port-specific DNAT rule.
type portDNATRule struct {
protocol gopacket.LayerType
origPort uint16
targetPort uint16
targetIP netip.Addr
}
// newBiDNATMap creates a new bidirectional DNAT mapping structure.
func newBiDNATMap() *biDNATMap { func newBiDNATMap() *biDNATMap {
return &biDNATMap{ return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr), forward: make(map[netip.Addr]netip.Addr),
@@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap {
} }
} }
// set adds a bidirectional DNAT mapping between original and translated addresses.
func (b *biDNATMap) set(original, translated netip.Addr) { func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated b.forward[original] = translated
b.reverse[translated] = original b.reverse[translated] = original
} }
// delete removes a bidirectional DNAT mapping for the given original address.
func (b *biDNATMap) delete(original netip.Addr) { func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists { if translated, exists := b.forward[original]; exists {
delete(b.forward, original) delete(b.forward, original)
@@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) {
} }
} }
// getTranslated returns the translated address for a given original address.
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original] translated, exists := b.forward[original]
return translated, exists return translated, exists
} }
// getOriginal returns the original address for a given translated address.
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated] original, exists := b.reverse[translated]
return original, exists return original, exists
} }
// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation.
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() { if !originalAddr.IsValid() {
return fmt.Errorf("invalid IP addresses") return fmt.Errorf("invalid original IP address")
}
if !translatedAddr.IsValid() {
return fmt.Errorf("invalid translated IP address")
} }
if m.localipmanager.IsLocalIP(translatedAddr) { if m.localipmanager.IsLocalIP(translatedAddr) {
@@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
m.dnatMutex.Lock() m.dnatMutex.Lock()
defer m.dnatMutex.Unlock() defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil { if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap() m.dnatBiMap = newBiDNATMap()
@@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr
return nil return nil
} }
// RemoveInternalDNATMapping removes a 1:1 IP address mapping // RemoveInternalDNATMapping removes a 1:1 IP address mapping.
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock() m.dnatMutex.Lock()
defer m.dnatMutex.Unlock() defer m.dnatMutex.Unlock()
@@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
return nil return nil
} }
// getDNATTranslation returns the translated address if a mapping exists // getDNATTranslation returns the translated address if a mapping exists.
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return addr, false return addr, false
@@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
return translated, exists return translated, exists
} }
// findReverseDNATMapping finds original address for return traffic // findReverseDNATMapping finds original address for return traffic.
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return translatedAddr, false return translatedAddr, false
@@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr,
return original, exists return original, exists
} }
// translateOutboundDNAT applies DNAT translation to outbound packets // translateOutboundDNAT applies DNAT translation to outbound packets.
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return false return false
} }
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP) translatedIP, exists := m.getDNATTranslation(dstIP)
@@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return false return false
} }
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
m.logger.Error1("Failed to rewrite packet destination: %v", err) m.logger.Error1("failed to rewrite packet destination: %v", err)
return false return false
} }
@@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
return true return true
} }
// translateInboundReverse applies reverse DNAT to inbound return traffic // translateInboundReverse applies reverse DNAT to inbound return traffic.
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() { if !m.dnatEnabled.Load() {
return false return false
} }
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP) originalIP, exists := m.findReverseDNATMapping(srcIP)
@@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return false return false
} }
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
m.logger.Error1("Failed to rewrite packet source: %v", err) m.logger.Error1("failed to rewrite packet source: %v", err)
return false return false
} }
@@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
return true return true
} }
// rewritePacketDestination replaces destination IP in the packet // rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums.
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { if !newIP.Is4() {
return ErrIPv4Only return ErrIPv4Only
} }
var oldDst [4]byte var oldIP [4]byte
copy(oldDst[:], packetData[16:20]) copy(oldIP[:], packetData[ipOffset:ipOffset+4])
newDst := newIP.As4() newIPBytes := newIP.As4()
copy(packetData[16:20], newDst[:]) copy(packetData[ipOffset:ipOffset+4], newIPBytes[:])
ipHeaderLen := int(d.ip4.IHL) * 4 ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length") return errInvalidIPHeaderLength
} }
binary.BigEndian.PutUint16(packetData[10:12], 0) binary.BigEndian.PutUint16(packetData[10:12], 0)
@@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP
if len(d.decoded) > 1 { if len(d.decoded) > 1 {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen) m.updateICMPChecksum(packetData, ipHeaderLen)
} }
@@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip
return nil return nil
} }
// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624.
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 { if len(packetData) < tcpStart+18 {
@@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
} }
// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624.
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen udpStart := ipHeaderLen
if len(packetData) < udpStart+8 { if len(packetData) < udpStart+8 {
@@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
} }
// updateICMPChecksum recalculates ICMP checksum after packet modification.
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 { if len(packetData) < icmpStart+8 {
@@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
binary.BigEndian.PutUint16(icmpData[2:4], checksum) binary.BigEndian.PutUint16(icmpData[2:4], checksum)
} }
// incrementalUpdate performs incremental checksum update per RFC 1624 // incrementalUpdate performs incremental checksum update per RFC 1624.
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum) sum := uint32(^oldChecksum)
@@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
return ^uint16(sum) return ^uint16(sum)
} }
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) // AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return nil, errNatNotSupported return nil, errNatNotSupported
@@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
return m.nativeFirewall.AddDNATRule(rule) return m.nativeFirewall.AddDNATRule(rule)
} }
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) // DeleteDNATRule deletes outbound DNAT rule.
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errNatNotSupported return errNatNotSupported
} }
return m.nativeFirewall.DeleteDNATRule(rule) return m.nativeFirewall.DeleteDNATRule(rule)
} }
// addPortRedirection adds a port redirection rule.
func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
rule := portDNATRule{
protocol: protocol,
origPort: sourcePort,
targetPort: targetPort,
targetIP: targetIP,
}
m.portDNATRules = append(m.portDNATRules, rule)
m.portDNATEnabled.Store(true)
return nil
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// removePortRedirection removes a port redirection rule.
func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error {
m.portDNATMutex.Lock()
defer m.portDNATMutex.Unlock()
m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool {
return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0
})
if len(m.portDNATRules) == 0 {
m.portDNATEnabled.Store(false)
}
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
case firewall.ProtocolTCP:
layerType = layers.LayerTypeTCP
case firewall.ProtocolUDP:
layerType = layers.LayerTypeUDP
default:
return fmt.Errorf("unsupported protocol: %s", protocol)
}
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {
return false
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
dstPort := uint16(d.tcp.DstPort)
return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort)
case layers.LayerTypeUDP:
dstPort := uint16(d.udp.DstPort)
return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort)
default:
return false
}
}
type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error
func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool {
m.portDNATMutex.RLock()
defer m.portDNATMutex.RUnlock()
for _, rule := range m.portDNATRules {
if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 {
continue
}
if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 {
return false
}
if rule.origPort != port {
continue
}
if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
m.logger.Error1("failed to rewrite port: %v", err)
return false
}
d.dnatOrigPort = rule.origPort
return true
}
return false
}
// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum.
func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+4 {
return fmt.Errorf("packet too short for TCP header")
}
portStart := tcpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
if len(packetData) >= tcpStart+18 {
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
return nil
}
// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum.
func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error {
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return errInvalidIPHeaderLength
}
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return fmt.Errorf("packet too short for UDP header")
}
portStart := udpStart + portOffset
oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2])
binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort)
checksumOffset := udpStart + 6
if len(packetData) >= udpStart+8 {
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum != 0 {
var oldPortBytes, newPortBytes [2]byte
binary.BigEndian.PutUint16(oldPortBytes[:], oldPort)
binary.BigEndian.PutUint16(newPortBytes[:], newPort)
newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:])
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
}
return nil
}

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
func BenchmarkDNATConcurrency(b *testing.B) { func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
func BenchmarkDNATMemoryAllocations(b *testing.B) { func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
@@ -414,3 +415,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) {
} }
}) })
} }
// BenchmarkPortDNAT measures the performance of port DNAT operations
func BenchmarkPortDNAT(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
useMatchPort bool
description string
}{
{
name: "tcp_inbound_dnat_match",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: true,
description: "TCP inbound port DNAT translation (22 → 22022)",
},
{
name: "tcp_inbound_dnat_nomatch",
proto: layers.IPProtocolTCP,
setupDNAT: true,
useMatchPort: false,
description: "TCP inbound with DNAT configured but no port match",
},
{
name: "tcp_inbound_no_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
useMatchPort: false,
description: "TCP inbound without DNAT (baseline)",
},
{
name: "udp_inbound_dnat_match",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: true,
description: "UDP inbound port DNAT translation (5353 → 22054)",
},
{
name: "udp_inbound_dnat_nomatch",
proto: layers.IPProtocolUDP,
setupDNAT: true,
useMatchPort: false,
description: "UDP inbound with DNAT configured but no port match",
},
{
name: "udp_inbound_no_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
useMatchPort: false,
description: "UDP inbound without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
var origPort, targetPort, testPort uint16
if sc.proto == layers.IPProtocolTCP {
origPort, targetPort = 22, 22022
} else {
origPort, targetPort = 5353, 22054
}
if sc.useMatchPort {
testPort = origPort
} else {
testPort = 443 // Different port
}
// Setup port DNAT mapping if needed
if sc.setupDNAT {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort)
require.NoError(b, err)
}
// Pre-establish inbound connection for outbound reverse test
if sc.setupDNAT && sc.useMatchPort {
inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort)
manager.filterInbound(inboundPacket, 0)
}
b.ResetTimer()
b.ReportAllocs()
// Benchmark inbound DNAT translation
b.Run("inbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time
packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort)
manager.filterInbound(packet, 0)
}
})
// Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches)
if sc.setupDNAT && sc.useMatchPort {
b.Run("outbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh return packet (from target port)
packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321)
manager.filterOutbound(packet, 0)
}
})
}
})
}
}

View File

@@ -8,6 +8,8 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -15,7 +17,7 @@ import (
func TestDNATTranslationCorrectness(t *testing.T) { func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -99,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
func TestDNATMappingManagement(t *testing.T) { func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -143,3 +145,111 @@ func TestDNATMappingManagement(t *testing.T) {
err = manager.RemoveInternalDNATMapping(originalIP) err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping") require.Error(t, err, "Should error when removing non-existent mapping")
} }
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
testCases := []struct {
name string
protocol layers.IPProtocol
sourcePort uint16
targetPort uint16
}{
{"TCP SSH", layers.IPProtocolTCP, 22, 22022},
{"UDP DNS", layers.IPProtocolUDP, 5353, 22054},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort)
d := parsePacket(t, inboundPacket)
translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr)
require.True(t, translated, "Inbound packet should be translated")
d = parsePacket(t, inboundPacket)
var dstPort uint16
switch tc.protocol {
case layers.IPProtocolTCP:
dstPort = uint16(d.tcp.DstPort)
case layers.IPProtocolUDP:
dstPort = uint16(d.udp.DstPort)
}
require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port")
err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort)
require.NoError(t, err)
})
}
}
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
localAddr := netip.MustParseAddr("100.0.2.175")
clientIP := netip.MustParseAddr("100.0.169.249")
err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcIP netip.Addr
dstIP netip.Addr
srcPort uint16
dstPort uint16
}{
{"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80},
{"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22},
{"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22},
{"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort)
d := parsePacket(t, packet)
translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP)
require.False(t, translated, "Packet should NOT be translated for %s", tc.name)
d = parsePacket(t, packet)
if tc.protocol == layers.IPProtocolTCP {
require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged")
} else if tc.protocol == layers.IPProtocolUDP {
require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged")
}
})
}
}
func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol {
switch proto {
case layers.IPProtocolTCP:
return firewall.ProtocolTCP
case layers.IPProtocolUDP:
return firewall.ProtocolUDP
default:
return firewall.ProtocolALL
}
}

View File

@@ -16,25 +16,33 @@ type PacketStage int
const ( const (
StageReceived PacketStage = iota StageReceived PacketStage = iota
StageInboundPortDNAT
StageInbound1to1NAT
StageConntrack StageConntrack
StagePeerACL StagePeerACL
StageRouting StageRouting
StageRouteACL StageRouteACL
StageForwarding StageForwarding
StageCompleted StageCompleted
StageOutbound1to1NAT
StageOutboundPortReverse
) )
const msgProcessingCompleted = "Processing completed" const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string { func (s PacketStage) String() string {
return map[PacketStage]string{ return map[PacketStage]string{
StageReceived: "Received", StageReceived: "Received",
StageConntrack: "Connection Tracking", StageInboundPortDNAT: "Inbound Port DNAT",
StagePeerACL: "Peer ACL", StageInbound1to1NAT: "Inbound 1:1 NAT",
StageRouting: "Routing", StageConntrack: "Connection Tracking",
StageRouteACL: "Route ACL", StagePeerACL: "Peer ACL",
StageForwarding: "Forwarding", StageRouting: "Routing",
StageCompleted: "Completed", StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
StageOutbound1to1NAT: "Outbound 1:1 NAT",
StageOutboundPortReverse: "Outbound DNAT Reverse",
}[s] }[s]
} }
@@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) {
return trace
}
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
@@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
} }
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageCompleted, "Packet dropped - decode error", false)
return trace
}
m.handleOutboundDNAT(trace, packetData, d)
dropped := m.filterOutbound(packetData, 0) dropped := m.filterOutbound(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
@@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr
} }
return trace return trace
} }
func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool {
portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d)
if portDNATApplied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
trace.DestinationPort = m.getDestPort(d)
}
nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d)
if nat1to1Applied {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false)
return true
}
*srcIP, *dstIP = m.extractIPs(d)
}
return false
}
func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true)
return false
}
protocol := d.decoded[1]
if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP {
trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var originalPort uint16
if protocol == layers.LayerTypeTCP {
originalPort = uint16(d.tcp.DstPort)
} else {
originalPort = uint16(d.udp.DstPort)
}
translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP)
if translated {
ipHeaderLen := int((packetData[0] & 0x0F) * 4)
translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3])
protoStr := "TCP"
if protocol == layers.LayerTypeUDP {
protoStr = "UDP"
}
msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort)
trace.AddResult(StageInboundPortDNAT, msg, true)
return true
}
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true)
return false
}
func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
translated := m.translateInboundReverse(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatBiMap.getOriginal(srcIP)
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP)
trace.AddResult(StageInbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) {
m.traceOutbound1to1NAT(trace, packetData, d)
m.traceOutboundPortReverse(trace, packetData, d)
}
func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true)
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translated := m.translateOutboundDNAT(packetData, d)
if translated {
m.dnatMutex.RLock()
translatedIP, exists := m.dnatMappings[dstIP]
m.dnatMutex.RUnlock()
if exists {
msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP)
trace.AddResult(StageOutbound1to1NAT, msg, true)
return true
}
}
trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true)
return false
}
func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool {
if !m.portDNATEnabled.Load() {
trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true)
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true)
return false
}
if len(d.decoded) < 2 {
trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true)
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
var origPort uint16
transport := d.decoded[1]
switch transport {
case layers.LayerTypeTCP:
srcPort := uint16(d.tcp.SrcPort)
dstPort := uint16(d.tcp.DstPort)
conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
case layers.LayerTypeUDP:
srcPort := uint16(d.udp.SrcPort)
dstPort := uint16(d.udp.DstPort)
conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort)
if exists {
origPort = uint16(conn.DNATOrigPort.Load())
}
if origPort != 0 {
msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort)
trace.AddResult(StageOutboundPortReverse, msg, true)
return true
}
default:
trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true)
return false
}
trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true)
return false
}
func (m *Manager) getDestPort(d *decoder) uint16 {
if len(d.decoded) < 2 {
return 0
}
switch d.decoded[1] {
case layers.LayerTypeTCP:
return uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
return uint16(d.udp.DstPort)
default:
return 0
}
}

View File

@@ -10,6 +10,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
if !statefulMode { if !statefulMode {
@@ -104,6 +105,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -126,6 +129,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -153,6 +158,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -179,6 +186,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -204,6 +213,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -228,6 +239,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -246,6 +259,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageRouteACL, StageRouteACL,
@@ -264,6 +279,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StageCompleted, StageCompleted,
@@ -287,6 +304,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageCompleted, StageCompleted,
}, },
@@ -301,6 +320,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted, StageCompleted,
}, },
expectedAllow: true, expectedAllow: true,
@@ -319,6 +340,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -340,6 +363,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -362,6 +387,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -382,6 +409,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack, StageConntrack,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
@@ -406,6 +435,8 @@ func TestTracePacket(t *testing.T) {
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageRouting, StageRouting,
StagePeerACL, StagePeerACL,
StageCompleted, StageCompleted,

View File

@@ -4,12 +4,15 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt"
"runtime" "runtime"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
@@ -17,6 +20,9 @@ import (
"github.com/netbirdio/netbird/util/embeddedroots" "github.com/netbirdio/netbird/util/embeddedroots"
) )
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
// Backoff returns a backoff configuration for gRPC calls // Backoff returns a backoff configuration for gRPC calls
func Backoff(ctx context.Context) backoff.BackOff { func Backoff(ctx context.Context) backoff.BackOff {
b := backoff.NewExponentialBackOff() b := backoff.NewExponentialBackOff()
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx) return backoff.WithContext(b, ctx)
} }
// waitForConnectionReady blocks until the connection becomes ready or fails.
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
conn.Connect()
state := conn.GetState()
for state != connectivity.Ready && state != connectivity.Shutdown {
if !conn.WaitForStateChange(ctx, state) {
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
}
state = conn.GetState()
}
if state == connectivity.Shutdown {
return ErrConnectionShutdown
}
return nil
}
// CreateConnection creates a gRPC client connection with the appropriate transport options. // CreateConnection creates a gRPC client connection with the appropriate transport options.
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
})) }))
} }
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) conn, err := grpc.NewClient(
defer cancel()
conn, err := grpc.DialContext(
connCtx,
addr, addr,
transportOption, transportOption,
WithCustomDialer(tlsEnabled, component), WithCustomDialer(tlsEnabled, component),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{ grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 30 * time.Second, Time: 30 * time.Second,
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
}), }),
) )
if err != nil { if err != nil {
log.Printf("DialContext error: %v", err) return nil, fmt.Errorf("new client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
if err := waitForConnectionReady(ctx, conn); err != nil {
_ = conn.Close()
return nil, err return nil, err
} }

View File

@@ -18,7 +18,7 @@ import (
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { func WithCustomDialer(_ bool, _ string) grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
currentUser, err := user.Current() currentUser, err := user.Current()
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
log.Errorf("Failed to dial: %s", err)
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
} }
return conn, nil return conn, nil

View File

@@ -29,11 +29,6 @@ type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
} }
type protoMatch struct {
ips map[string]int
policyID []byte
}
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
firewall firewall.Manager firewall firewall.Manager
@@ -86,21 +81,14 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
} }
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap) rules := networkMap.FirewallRules
enableSSH := networkMap.PeerConfig != nil && enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled networkMap.PeerConfig.SshConfig.SshEnabled
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
// if TCP protocol rules not squashed and SSH enabled // If SSH enabled, add default firewall rule which accepts connection to any peer
// we add default firewall rule which accepts connection to any peer // in the network by SSH (TCP port defined by ssh.DefaultSSHPort).
// in the network by SSH (TCP 22 port).
if enableSSH { if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{ rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
@@ -368,145 +356,6 @@ func (d *DefaultManager) getPeerRuleID(
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
} }
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
// to all peers in the network map to one rule which just accepts that type of the traffic.
//
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
totalIPs++
}
}
in := map[mgmProto.RuleProtocol]*protoMatch{}
out := map[mgmProto.RuleProtocol]*protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
// But we zeroed the IP's for protocol if:
// 1. Any of the rule has DROP action type.
// 2. Any of rule contains Port.
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return
}
if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},
// store the first encountered PolicyID for this protocol
policyID: r.PolicyID,
}
}
// special case, when we receive this all network IP address
// it means that rules for that protocol was already optimized on the
// management side
if r.PeerIP == "0.0.0.0" {
squashedRules = append(squashedRules, r)
squashedProtocols[r.Protocol] = struct{}{}
return
}
ipset := protocols[r.Protocol].ips
if _, ok := ipset[r.PeerIP]; ok {
return
}
ipset[r.PeerIP] = i
}
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
}
}
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
protocolOrders := []mgmProto.RuleProtocol{
mgmProto.RuleProtocol_ALL,
mgmProto.RuleProtocol_ICMP,
mgmProto.RuleProtocol_TCP,
mgmProto.RuleProtocol_UDP,
}
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
match, ok := matches[protocol]
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
// don't squash if :
// 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network.
continue
}
// add special rule 0.0.0.0 which allows all IP's in our firewall implementations
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
PolicyID: match.policyID,
})
squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.RuleProtocol_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
}
}
}
squash(in, mgmProto.RuleDirection_IN)
squash(out, mgmProto.RuleDirection_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
return squashedRules, squashedProtocols
}
if len(squashedRules) == 0 {
return networkMap.FirewallRules, squashedProtocols
}
var rules []*mgmProto.FirewallRule
// filter out rules which was squashed from final list
// if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue
}
}
rules = append(rules, r)
}
return append(rules, squashedRules...), squashedProtocols
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector // getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string { func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo) return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err = fw.Close(nil) err = fw.Close(nil)
@@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) {
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err = fw.Close(nil) err = fw.Close(nil)
@@ -188,492 +189,6 @@ func TestDefaultManagerStateless(t *testing.T) {
}) })
} }
func TestDefaultManagerSquashRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, 2, len(rules))
r := rules[0]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
r = rules[1]
assert.Equal(t, "0.0.0.0", r.PeerIP)
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
}
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) { func TestPortInfoEmpty(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -807,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err = fw.Close(nil) err = fw.Close(nil)

View File

@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
deviceCode.VerificationURIComplete = deviceCode.VerificationURI deviceCode.VerificationURIComplete = deviceCode.VerificationURI
} }
if d.providerConfig.LoginHint != "" {
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
if deviceCode.VerificationURI != "" {
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
}
}
return deviceCode, err return deviceCode, err
} }
func appendLoginHint(uri, loginHint string) string {
if uri == "" || loginHint == "" {
return uri
}
parsedURL, err := url.Parse(uri)
if err != nil {
log.Debugf("failed to parse verification URI for login_hint: %v", err)
return uri
}
query := parsedURL.Query()
query.Set("login_hint", loginHint)
parsedURL.RawQuery = query.Encode()
return parsedURL.String()
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) { func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{} form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID) form.Add("client_id", d.providerConfig.ClientID)

View File

@@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config, hint)
} }
pkceFlow, err := authenticateWithPKCEFlow(ctx, config) pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
if err != nil { if err != nil {
// fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err) log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow") log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config, hint)
} }
return pkceFlow, nil return pkceFlow, nil
} }
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
} }
pkceFlowInfo.ProviderConfig.LoginHint = hint
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
} }
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil { if err != nil {
switch s, ok := gstatus.FromError(err); { switch s, ok := gstatus.FromError(err); {
@@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
} }
} }
deviceFlowInfo.ProviderConfig.LoginHint = hint
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
} }

View File

@@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
params = append(params, oauth2.SetAuthURLParam("max_age", "0")) params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
} }
} }
if p.providerConfig.LoginHint != "" {
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...) authURL := p.oAuthConfig.AuthCodeURL(state, params...)

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
<-engineCtx.Done() <-engineCtx.Done()
c.engineMutex.Lock() c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil { engine := c.engine
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) c.engine = nil
if err := c.engine.Stop(); err != nil { c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err) log.Errorf("Failed to stop engine: %v", err)
} }
c.engine = nil
} }
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown() c.statusRecorder.ClientTeardown()
backOff.Reset() backOff.Reset()
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
} }
func (c *ConnectClient) Stop() error { func (c *ConnectClient) Stop() error {
if c == nil { engine := c.Engine()
return nil if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
} }
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
if c.engine == nil {
return nil
}
if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil return nil
} }

View File

@@ -44,10 +44,12 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client. config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
state.json: Anonymized client state dump containing netbird states. state.json: Anonymized client state dump containing netbird states for the active profile.
mutex.prof: Mutex profiling information. mutex.prof: Mutex profiling information.
goroutine.prof: Goroutine profiling information. goroutine.prof: Goroutine profiling information.
block.prof: Block profiling information. block.prof: Block profiling information.
@@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information:
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing. The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged. For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
DNS Configuration
The debug bundle includes platform-specific DNS configuration files:
resolv.conf (Unix systems):
- Contains DNS resolver configuration from /etc/resolv.conf
- Includes nameserver entries, search domains, and resolver options
- All IP addresses and domain names are anonymized following the same rules as other files
scutil_dns.txt (macOS only):
- Contains detailed DNS configuration from scutil --dns
- Shows DNS configuration for all network interfaces
- Includes search domains, nameservers, and DNS resolver settings
- All IP addresses and domain names are anonymized
` `
const ( const (
@@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() {
if err := g.addFirewallRules(); err != nil { if err := g.addFirewallRules(); err != nil {
log.Errorf("failed to add firewall rules to debug bundle: %v", err) log.Errorf("failed to add firewall rules to debug bundle: %v", err)
} }
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}
} }
func (g *BundleGenerator) addReadme() error { func (g *BundleGenerator) addReadme() error {
@@ -564,6 +584,8 @@ func (g *BundleGenerator) addStateFile() error {
return nil return nil
} }
log.Debugf("Adding state file from: %s", path)
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {

View File

@@ -0,0 +1,53 @@
//go:build darwin && !ios
package debug
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
if err := g.addScutilDNS(); err != nil {
log.Errorf("failed to add scutil DNS output: %v", err)
}
return nil
}
func (g *BundleGenerator) addScutilDNS() error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "scutil", "--dns")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("execute scutil --dns: %w", err)
}
if len(bytes.TrimSpace(output)) == 0 {
return fmt.Errorf("no scutil DNS output")
}
content := string(output)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
return fmt.Errorf("add scutil DNS output to zip: %w", err)
}
return nil
}

View File

@@ -5,3 +5,7 @@ package debug
func (g *BundleGenerator) addRoutes() error { func (g *BundleGenerator) addRoutes() error {
return nil return nil
} }
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -0,0 +1,16 @@
//go:build unix && !darwin && !android
package debug
import (
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
return nil
}

View File

@@ -0,0 +1,7 @@
//go:build !unix
package debug
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -0,0 +1,29 @@
//go:build unix && !android
package debug
import (
"fmt"
"os"
"strings"
)
const resolvConfPath = "/etc/resolv.conf"
func (g *BundleGenerator) addResolvConf() error {
data, err := os.ReadFile(resolvConfPath)
if err != nil {
return fmt.Errorf("read %s: %w", resolvConfPath, err)
}
content := string(data)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
return fmt.Errorf("add resolv.conf to zip: %w", err)
}
return nil
}

View File

@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
Scope string Scope string
// UseIDToken indicates if the id token should be used for authentication // UseIDToken indicates if the id token should be used for authentication
UseIDToken bool UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
} }
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
} }
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var ( var (
searchDomains []string searchDomains []string
matchDomains []string matchDomains []string
) )
err = s.recordSystemDNSSettings(true) if err := s.recordSystemDNSSettings(true); err != nil {
if err != nil {
log.Errorf("unable to update record of System's DNS config: %s", err.Error()) log.Errorf("unable to update record of System's DNS config: %s", err.Error())
} }
if config.RouteAll { if config.RouteAll {
searchDomains = append(searchDomains, "\"\"") searchDomains = append(searchDomains, "\"\"")
err = s.addLocalDNS() if err := s.addLocalDNS(); err != nil {
if err != nil { log.Warnf("failed to add local DNS: %v", err)
log.Infof("failed to enable split DNS")
} }
s.updateState(stateManager)
} }
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
var err error
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
} else { } else {
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil { if err != nil {
return fmt.Errorf("add match domains: %w", err) return fmt.Errorf("add match domains: %w", err)
} }
s.updateState(stateManager)
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
if len(searchDomains) != 0 { if len(searchDomains) != 0 {
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
if err != nil { if err != nil {
return fmt.Errorf("add search domains: %w", err) return fmt.Errorf("add search domains: %w", err)
} }
s.updateState(stateManager)
if err := s.flushDNSCache(); err != nil { if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err) log.Errorf("failed to flush DNS cache: %v", err)
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
return nil return nil
} }
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (s *systemConfigurator) string() string { func (s *systemConfigurator) string() string {
return "scutil" return "scutil"
} }
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
func (s *systemConfigurator) addLocalDNS() error { func (s *systemConfigurator) addLocalDNS() error {
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
if err := s.recordSystemDNSSettings(true); err != nil { if err := s.recordSystemDNSSettings(true); err != nil {
log.Errorf("Unable to get system DNS configuration")
return fmt.Errorf("recordSystemDNSSettings(): %w", err) return fmt.Errorf("recordSystemDNSSettings(): %w", err)
} }
} }
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
if err != nil {
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
}
} else {
log.Info("Not enabling local DNS server") log.Info("Not enabling local DNS server")
return nil
}
if err := s.addSearchDomains(
localKey,
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
); err != nil {
return fmt.Errorf("add search domains: %w", err)
} }
return nil return nil

View File

@@ -0,0 +1,111 @@
//go:build !ios
package dns
import (
"context"
"net/netip"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
if testing.Short() {
t.Skip("skipping scutil integration test in short mode")
}
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
defer func() {
require.NoError(t, sm.Stop(context.Background()))
}()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
config := HostDNSConfig{
ServerIP: netip.MustParseAddr("100.64.0.1"),
ServerPort: 53,
RouteAll: true,
Domains: []DomainConfig{
{Domain: "example.com", MatchOnly: true},
},
}
err := configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
require.NoError(t, sm.PersistState(context.Background()))
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
defer func() {
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}()
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
if exists {
t.Logf("Key %s exists before cleanup", key)
}
}
sm2 := statemanager.New(stateFile)
sm2.RegisterState(&ShutdownState{})
err = sm2.LoadState(&ShutdownState{})
require.NoError(t, err)
state := sm2.GetState(&ShutdownState{})
if state == nil {
t.Skip("State not saved, skipping cleanup test")
}
shutdownState, ok := state.(*ShutdownState)
require.True(t, ok)
err = shutdownState.Cleanup()
require.NoError(t, err)
for _, key := range []string{searchKey, matchKey, localKey} {
exists, err := checkDNSKeyExists(key)
require.NoError(t, err)
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
}
}
func checkDNSKeyExists(key string) (bool, error) {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
output, err := cmd.CombinedOutput()
if err != nil {
if strings.Contains(string(output), "No such key") {
return false, nil
}
return false, err
}
return !strings.Contains(string(output), "No such key"), nil
}
func removeTestDNSKey(key string) error {
cmd := exec.Command(scutilPath)
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
_, err := cmd.CombinedOutput()
return err
}

View File

@@ -17,6 +17,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/winregistry"
) )
var ( var (
@@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{ r.updateState(stateManager)
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
var searchDomains, matchDomains []string var searchDomains, matchDomains []string
for _, dConf := range config.Domains { for _, dConf := range config.Domains {
@@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
} }
if err := r.removeDNSMatchPolicies(); err != nil {
log.Errorf("cleanup old dns match policies: %s", err)
}
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
if err != nil { if err != nil {
@@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
} }
r.nrptEntryCount = count r.nrptEntryCount = count
} else { } else {
if err := r.removeDNSMatchPolicies(); err != nil {
return fmt.Errorf("remove dns match policies: %w", err)
}
r.nrptEntryCount = 0 r.nrptEntryCount = 0
} }
if err := stateManager.UpdateState(&ShutdownState{ r.updateState(stateManager)
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
if err := r.updateSearchDomains(searchDomains); err != nil { if err := r.updateSearchDomains(searchDomains); err != nil {
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
@@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil return nil
} }
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,
GPO: r.gpo,
NRPTEntryCount: r.nrptEntryCount,
}); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
}
}
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
return fmt.Errorf("adding dns setup for all failed: %w", err) return fmt.Errorf("adding dns setup for all failed: %w", err)
@@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s
return fmt.Errorf("remove existing dns policy: %w", err) return fmt.Errorf("remove existing dns policy: %w", err)
} }
regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE)
if err != nil { if err != nil {
return fmt.Errorf("create registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err)
} }
defer closer(regKey) defer closer(regKey)

View File

@@ -0,0 +1,102 @@
package dns
import (
"fmt"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/registry"
)
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
// when the number of match domains decreases between configuration changes.
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
config5 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
{Domain: "domain3.com", MatchOnly: true},
{Domain: "domain4.com", MatchOnly: true},
{Domain: "domain5.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config5, nil)
require.NoError(t, err)
// Verify all 5 entries exist
for i := 0; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after first config", i)
}
config2 := HostDNSConfig{
ServerIP: testIP,
Domains: []DomainConfig{
{Domain: "domain1.com", MatchOnly: true},
{Domain: "domain2.com", MatchOnly: true},
},
}
err = cfg.applyDNSConfig(config2, nil)
require.NoError(t, err)
// Verify first 2 entries exist
for i := 0; i < 2; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "Entry %d should exist after second config", i)
}
// Verify entries 2-4 are cleaned up
for i := 2; i < 5; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
}
}
func registryKeyExists(path string) (bool, error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
if err != nil {
if err == registry.ErrNotExist {
return false, nil
}
return false, err
}
k.Close()
return true, nil
}
func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10}
_ = cfg.removeDNSMatchPolicies()
}

View File

@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
// DefaultServer dns server object // DefaultServer dns server object
type DefaultServer struct { type DefaultServer struct {
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
shutdownWg sync.WaitGroup
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running. // disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
// This is different from ServiceEnable=false from management which completely disables the DNS service. // This is different from ServiceEnable=false from management which completely disables the DNS service.
disableSys bool disableSys bool
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
// Stop stops the server // Stop stops the server
func (s *DefaultServer) Stop() { func (s *DefaultServer) Stop() {
s.ctxCancel() s.ctxCancel()
s.shutdownWg.Wait()
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.applyHostConfig() s.applyHostConfig()
s.shutdownWg.Add(1)
go func() { go func() {
// persist dns state right away defer s.shutdownWg.Done()
if err := s.stateManager.PersistState(s.ctx); err != nil { if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err) log.Errorf("Failed to persist dns state: %v", err)
} }

View File

@@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err return nil, err
} }
pf, err := uspfilter.Create(wgIface, false, flowLogger) pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
if err != nil { if err != nil {
t.Fatalf("failed to create uspfilter: %v", err) t.Fatalf("failed to create uspfilter: %v", err)
return nil, err return nil, err

View File

@@ -7,6 +7,7 @@ import (
) )
type ShutdownState struct { type ShutdownState struct {
CreatedKeys []string
} }
func (s *ShutdownState) Name() string { func (s *ShutdownState) Name() string {
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create host manager: %w", err) return fmt.Errorf("create host manager: %w", err)
} }
for _, key := range s.CreatedKeys {
manager.createdKeys[key] = struct{}{}
}
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }

View File

@@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) {
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
} }
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -33,7 +34,7 @@ type firewaller interface {
} }
type DNSForwarder struct { type DNSForwarder struct {
listenAddress string listenAddress netip.AddrPort
ttl uint32 ttl uint32
statusRecorder *peer.Status statusRecorder *peer.Status
@@ -47,9 +48,11 @@ type DNSForwarder struct {
firewall firewaller firewall firewaller
resolver resolver resolver resolver
cache *cache cache *cache
wgIface wgIface
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{ return &DNSForwarder{
listenAddress: listenAddress, listenAddress: listenAddress,
@@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver, resolver: net.DefaultResolver,
cache: newCache(), cache: newCache(),
wgIface: wgIface,
} }
} }
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("starting DNS forwarder on address=%s", f.listenAddress) var netstackNet *netstack.Net
if f.wgIface != nil {
netstackNet = f.wgIface.GetNet()
}
addrDesc := f.listenAddress.String()
if netstackNet != nil {
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
}
log.Infof("starting DNS forwarder on address=%s", addrDesc)
udpLn, err := f.createUDPListener(netstackNet)
if err != nil {
return fmt.Errorf("create UDP listener: %w", err)
}
tcpLn, err := f.createTCPListener(netstackNet)
if err != nil {
return fmt.Errorf("create TCP listener: %w", err)
}
// UDP server
mux := dns.NewServeMux() mux := dns.NewServeMux()
f.mux = mux f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP) mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{ f.dnsServer = &dns.Server{
Addr: f.listenAddress, PacketConn: udpLn,
Net: "udp", Handler: mux,
Handler: mux,
} }
// TCP server
tcpMux := dns.NewServeMux() tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP) tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{ f.tcpServer = &dns.Server{
Addr: f.listenAddress, Listener: tcpLn,
Net: "tcp", Handler: tcpMux,
Handler: tcpMux,
} }
f.UpdateDomains(entries) f.UpdateDomains(entries)
@@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
errCh := make(chan error, 2) errCh := make(chan error, 2)
go func() { go func() {
log.Infof("DNS UDP listener running on %s", f.listenAddress) log.Infof("DNS UDP listener running on %s", addrDesc)
errCh <- f.dnsServer.ListenAndServe() errCh <- f.dnsServer.ActivateAndServe()
}() }()
go func() { go func() {
log.Infof("DNS TCP listener running on %s", f.listenAddress) log.Infof("DNS TCP listener running on %s", addrDesc)
errCh <- f.tcpServer.ListenAndServe() errCh <- f.tcpServer.ActivateAndServe()
}() }()
// return the first error we get (e.g. bind failure or shutdown)
return <-errCh return <-errCh
} }
func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
if netstackNet != nil {
return netstackNet.ListenUDPAddrPort(f.listenAddress)
}
return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
}
func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
if netstackNet != nil {
return netstackNet.ListenTCPAddrPort(f.listenAddress)
}
return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
}
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()

View File

@@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
} }
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain) d, err := domain.FromString(tt.configuredDomain)
@@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
mockResolver := &MockResolver{} mockResolver := &MockResolver{}
// Set up forwarder // Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver forwarder.resolver = mockResolver
// Create entries and track sets // Create entries and track sets
@@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{} mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{} mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver forwarder.resolver = mockResolver
// Configure a single domain // Configure a single domain
@@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
d, err := domain.FromString(tt.configured) d, err := domain.FromString(tt.configured)
require.NoError(t, err) require.NoError(t, err)
@@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
func TestDNSForwarder_TCPTruncation(t *testing.T) { func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set // Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{} mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com") d, _ := domain.FromString("example.com")
@@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
// a subsequent upstream failure still returns a successful response from cache. // a subsequent upstream failure still returns a successful response from cache.
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
mockResolver := &MockResolver{} mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver forwarder.resolver = mockResolver
d, err := domain.FromString("example.com") d, err := domain.FromString("example.com")
@@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
// Verifies that cache normalization works across casing and trailing dot variations. // Verifies that cache normalization works across casing and trailing dot variations.
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
mockResolver := &MockResolver{} mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver forwarder.resolver = mockResolver
d, err := domain.FromString("ExAmPlE.CoM") d, err := domain.FromString("ExAmPlE.CoM")
@@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
mockFirewall := &MockFirewall{} mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{} mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver forwarder.resolver = mockResolver
// Set up complex overlapping patterns // Set up complex overlapping patterns
@@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
mockFirewall := &MockFirewall{} mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{} mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
forwarder.resolver = mockResolver forwarder.resolver = mockResolver
d, err := domain.FromString("example.com") d, err := domain.FromString("example.com")
@@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
func TestDNSForwarder_EmptyQuery(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions // Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
query := &dns.Msg{} query := &dns.Msg{}
// Don't set any question // Don't set any question

View File

@@ -4,27 +4,34 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"sync" "net/netip"
"os"
"strconv"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
var ( const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also dnsTTL = 60
listenPort uint16 = 5353 envServerPort = "NB_DNS_FORWARDER_PORT"
listenPortMu sync.RWMutex
) )
const ( // wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder.
dnsTTL = 60 //seconds type wgIface interface {
) GetNet() *netstack.Net
Address() wgaddr.Address
}
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
type ForwarderEntry struct { type ForwarderEntry struct {
@@ -36,28 +43,30 @@ type ForwarderEntry struct {
type Manager struct { type Manager struct {
firewall firewall.Manager firewall firewall.Manager
statusRecorder *peer.Status statusRecorder *peer.Status
wgIface wgIface
serverPort uint16
fwRules []firewall.Rule fwRules []firewall.Rule
tcpRules []firewall.Rule tcpRules []firewall.Rule
dnsForwarder *DNSForwarder dnsForwarder *DNSForwarder
} }
func ListenPort() uint16 { func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager {
listenPortMu.RLock() serverPort := nbdns.ForwarderServerPort
defer listenPortMu.RUnlock() if envPort := os.Getenv(envServerPort); envPort != "" {
return listenPort if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
} serverPort = uint16(port)
log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort)
} else {
log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort)
}
}
func SetListenPort(port uint16) {
listenPortMu.Lock()
listenPort = port
listenPortMu.Unlock()
}
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
return &Manager{ return &Manager{
firewall: fw, firewall: fw,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgIface: wgIface,
serverPort: serverPort,
} }
} }
@@ -71,7 +80,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
return err return err
} }
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) localAddr := m.wgIface.Address().IP
if localAddr.IsValid() && m.firewall != nil {
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
} else {
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
}
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
} else {
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
}
}
listenAddress := netip.AddrPortFrom(localAddr, m.serverPort)
m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface)
go func() { go func() {
if err := m.dnsForwarder.Listen(fwdEntries); err != nil { if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
// todo handle close error if it is exists // todo handle close error if it is exists
@@ -96,6 +123,20 @@ func (m *Manager) Stop(ctx context.Context) error {
} }
var mErr *multierror.Error var mErr *multierror.Error
localAddr := m.wgIface.Address().IP
if localAddr.IsValid() && m.firewall != nil {
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
}
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
}
}
m.unregisterNetstackServices()
if err := m.dropDNSFirewall(); err != nil { if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err) mErr = multierror.Append(mErr, err)
} }
@@ -111,7 +152,7 @@ func (m *Manager) Stop(ctx context.Context) error {
func (m *Manager) allowDNSFirewall() error { func (m *Manager) allowDNSFirewall() error {
dport := &firewall.Port{ dport := &firewall.Port{
IsRange: false, IsRange: false,
Values: []uint16{ListenPort()}, Values: []uint16{m.serverPort},
} }
if m.firewall == nil { if m.firewall == nil {
@@ -120,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error {
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) return fmt.Errorf("add udp firewall rule: %w", err)
return err
} }
m.fwRules = dnsRules
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) return fmt.Errorf("add tcp firewall rule: %w", err)
return err
} }
if err := m.firewall.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
m.fwRules = dnsRules
m.tcpRules = tcpRules m.tcpRules = tcpRules
m.registerNetstackServices()
return nil return nil
} }
func (m *Manager) registerNetstackServices() {
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
if registrar, ok := m.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, m.serverPort)
registrar.RegisterNetstackService(nftypes.UDP, m.serverPort)
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
}
}
}
func (m *Manager) unregisterNetstackServices() {
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
if registrar, ok := m.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort)
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort)
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
}
}
}
func (m *Manager) dropDNSFirewall() error { func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error var mErr *multierror.Error
for _, rule := range m.fwRules { for _, rule := range m.fwRules {

View File

@@ -148,6 +148,8 @@ type Engine struct {
// syncMsgMux is used to guarantee sequential Management Service message processing // syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex syncMsgMux *sync.Mutex
// sshMux protects sshServer field access
sshMux sync.Mutex
config *EngineConfig config *EngineConfig
mobileDep MobileDependency mobileDep MobileDependency
@@ -200,11 +202,12 @@ type Engine struct {
flowManager nftypes.FlowManager flowManager nftypes.FlowManager
// WireGuard interface monitor // WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
// dns forwarder port // shutdownWg tracks all long-running goroutines to ensure clean shutdown
dnsFwdPort uint16 shutdownWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -247,7 +250,7 @@ func NewEngine(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
dnsFwdPort: dnsfwd.ListenPort(), probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
} }
sm := profilemanager.NewServiceManager("") sm := profilemanager.NewServiceManager("")
@@ -299,17 +302,12 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil e.ingressGatewayMgr = nil
} }
e.stopDNSForwarder()
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop(e.stateManager) e.routeManager.Stop(e.stateManager)
} }
if e.dnsForwardMgr != nil {
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
if e.srWatcher != nil { if e.srWatcher != nil {
e.srWatcher.Close() e.srWatcher.Close()
} }
@@ -326,10 +324,6 @@ func (e *Engine) Stop() error {
e.cancel() e.cancel()
} }
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close() e.close()
// stop flow manager after wg interface is gone // stop flow manager after wg interface is gone
@@ -337,8 +331,6 @@ func (e *Engine) Stop() error {
e.flowManager.Close() e.flowManager.Close()
} }
log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel() defer cancel()
@@ -349,12 +341,52 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
// Stop WireGuard interface monitor and wait for it to exit timeout := e.calculateShutdownTimeout()
e.wgIfaceMonitorWg.Wait() log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil return nil
} }
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
func (e *Engine) calculateShutdownTimeout() time.Duration {
peerCount := len(e.peerStore.PeersPubKey())
baseTimeout := 10 * time.Second
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
timeout := baseTimeout + perPeerTimeout
maxTimeout := 30 * time.Second
if timeout > maxTimeout {
timeout = maxTimeout
}
return timeout
}
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
@@ -484,14 +516,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes // monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor() e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1) e.shutdownWg.Add(1)
go func() { go func() {
defer e.wgIfaceMonitorWg.Done() defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err) log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine() e.triggerClientRestart()
} else if err != nil { } else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err) log.Warnf("WireGuard interface monitor: %s", err)
} }
@@ -507,7 +539,7 @@ func (e *Engine) createFirewall() error {
} }
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil { if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
return nil return nil
@@ -675,9 +707,11 @@ func (e *Engine) removeAllPeers() error {
func (e *Engine) removePeer(peerKey string) error { func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey) log.Debugf("removing peer from engine %s", peerKey)
e.sshMux.Lock()
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
e.sshServer.RemoveAuthorizedKey(peerKey) e.sshServer.RemoveAuthorizedKey(peerKey)
} }
e.sshMux.Unlock()
e.connMgr.RemovePeerConn(peerKey) e.connMgr.RemovePeerConn(peerKey)
@@ -879,6 +913,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS) log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil return nil
} }
e.sshMux.Lock()
// start SSH server if it wasn't running // start SSH server if it wasn't running
if isNil(e.sshServer) { if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort) listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
@@ -886,34 +921,42 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort) listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
} }
// nil sshServer means it has not yet been started // nil sshServer means it has not yet been started
var err error server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil { if err != nil {
e.sshMux.Unlock()
return fmt.Errorf("create ssh server: %w", err) return fmt.Errorf("create ssh server: %w", err)
} }
e.sshServer = server
e.sshMux.Unlock()
go func() { go func() {
// blocking // blocking
err = e.sshServer.Start() err = server.Start()
if err != nil { if err != nil {
// will throw error when we stop it even if it is a graceful stop // will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err) log.Debugf("stopped SSH server with error %v", err)
} }
e.syncMsgMux.Lock() e.sshMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil e.sshServer = nil
e.sshMux.Unlock()
log.Infof("stopped SSH server") log.Infof("stopped SSH server")
}() }()
} else { } else {
e.sshMux.Unlock()
log.Debugf("SSH server is already running") log.Debugf("SSH server is already running")
} }
} else if !isNil(e.sshServer) { } else {
// Disable SSH server request, so stop it if it was running e.sshMux.Lock()
err := e.sshServer.Stop() if !isNil(e.sshServer) {
if err != nil { // Disable SSH server request, so stop it if it was running
log.Warnf("failed to stop SSH server %v", err) err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
} }
e.sshServer = nil e.sshMux.Unlock()
} }
return nil return nil
} }
@@ -950,7 +993,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it. // E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() { func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() { go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks) info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil { if err != nil {
log.Warnf("failed to get system info with checks: %v", err) log.Warnf("failed to get system info with checks: %v", err)
@@ -1060,10 +1105,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{} protoDNSConfig = &mgmProto.DNSConfig{}
} }
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err) log.Errorf("failed to update dns server, err: %v", err)
} }
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1084,7 +1133,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Ingress forward rules // Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
@@ -1122,6 +1171,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications() e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys // update SSHServer by adding remote peer SSH keys
e.sshMux.Lock()
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() { for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
@@ -1132,6 +1182,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
} }
e.sshMux.Unlock()
} }
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1208,10 +1259,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
if forwarderPort == 0 {
forwarderPort = nbdns.ForwarderClientPort
}
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(), ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0), CustomZones: make([]nbdns.CustomZone, 0),
NameServerGroups: make([]*nbdns.NameServerGroup, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0),
ForwarderPort: forwarderPort,
} }
for _, zone := range protoDNSConfig.GetCustomZones() { for _, zone := range protoDNSConfig.GetCustomZones() {
@@ -1368,7 +1425,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() { func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1)
go func() { go func() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server // connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
@@ -1485,12 +1544,14 @@ func (e *Engine) close() {
e.statusRecorder.SetWgIface(nil) e.statusRecorder.SetWgIface(nil)
} }
e.sshMux.Lock()
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
err := e.sshServer.Stop() err := e.sshServer.Stop()
if err != nil { if err != nil {
log.Warnf("failed stopping the SSH server: %v", err) log.Warnf("failed stopping the SSH server: %v", err)
} }
} }
e.sshMux.Unlock()
if e.firewall != nil { if e.firewall != nil {
err := e.firewall.Close(e.stateManager) err := e.firewall.Close(e.stateManager)
@@ -1667,7 +1728,7 @@ func (e *Engine) getRosenpassAddr() string {
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states. // and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool { func (e *Engine) RunHealthProbes(waitForResult bool) bool {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
signalHealthy := e.signal.IsHealthy() signalHealthy := e.signal.IsHealthy()
@@ -1699,8 +1760,12 @@ func (e *Engine) RunHealthProbes() bool {
} }
e.syncMsgMux.Unlock() e.syncMsgMux.Unlock()
var results []relay.ProbeResult
results := e.probeICE(stuns, turns) if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results) e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true relayHealthy := true
@@ -1717,15 +1782,10 @@ func (e *Engine) RunHealthProbes() bool {
return allHealthy return allHealthy
} }
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { // triggerClientRestart triggers a full client restart by cancelling the client context.
return append( // Note: This does NOT just restart the engine - it cancels the entire client context,
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), // which causes the connect client's retry loop to create a completely new engine.
relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)..., func (e *Engine) triggerClientRestart() {
)
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
@@ -1747,7 +1807,9 @@ func (e *Engine) startNetworkMonitor() {
} }
e.networkMonitor = networkmonitor.New() e.networkMonitor = networkmonitor.New()
e.shutdownWg.Add(1)
go func() { go func() {
defer e.shutdownWg.Done()
if err := e.networkMonitor.Listen(e.ctx); err != nil { if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped") log.Infof("network monitor stopped")
@@ -1757,8 +1819,8 @@ func (e *Engine) startNetworkMonitor() {
return return
} }
log.Infof("Network monitor: detected network change, restarting engine") log.Infof("Network monitor: detected network change, triggering client restart")
e.restartEngine() e.triggerClientRestart()
}() }()
} }
@@ -1843,64 +1905,50 @@ func (e *Engine) GetWgAddr() netip.Addr {
func (e *Engine) updateDNSForwarder( func (e *Engine) updateDNSForwarder(
enabled bool, enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry, fwdEntries []*dnsfwd.ForwarderEntry,
forwarderPort uint16,
) { ) {
if e.config.DisableServerRoutes { if e.config.DisableServerRoutes {
return return
} }
if forwarderPort > 0 {
dnsfwd.SetListenPort(forwarderPort)
}
if !enabled { if !enabled {
if e.dnsForwardMgr == nil { e.stopDNSForwarder()
return
}
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
return return
} }
if len(fwdEntries) > 0 { if len(fwdEntries) > 0 {
switch { if e.dnsForwardMgr == nil {
case e.dnsForwardMgr == nil: e.startDNSForwarder(fwdEntries)
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) } else {
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
log.Infof("started domain router service with %d entries", len(fwdEntries))
case e.dnsFwdPort != forwarderPort:
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort)
e.restartDnsFwd(fwdEntries, forwarderPort)
e.dnsFwdPort = forwarderPort
default:
e.dnsForwardMgr.UpdateDomains(fwdEntries) e.dnsForwardMgr.UpdateDomains(fwdEntries)
} }
} else if e.dnsForwardMgr != nil { } else if e.dnsForwardMgr != nil {
log.Infof("disable domain router service") log.Infof("disable domain router service")
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { e.stopDNSForwarder()
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
} }
} }
func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) {
log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface)
// stop and start the forwarder to apply the new port
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err) log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil e.dnsForwardMgr = nil
return
} }
log.Infof("started domain router service with %d entries", len(fwdEntries))
}
func (e *Engine) stopDNSForwarder() {
if e.dnsForwardMgr == nil {
return
}
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
} }
func (e *Engine) GetNet() (*netstack.Net, error) { func (e *Engine) GetNet() (*netstack.Net, error) {

View File

@@ -10,10 +10,10 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/dns"
) )
type rcvChan chan *types.EventFields type rcvChan chan *types.EventFields
@@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
// check dns collection // check dns collection
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { if !l.dnsCollection.Load() && event.Protocol == types.UDP &&
(event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) {
return false return false
} }

View File

@@ -24,6 +24,7 @@ import (
// Manager handles netflow tracking and logging // Manager handles netflow tracking and logging
type Manager struct { type Manager struct {
mux sync.Mutex mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker conntrack nftypes.ConnTracker
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel m.cancel = cancel
go m.receiveACKs(ctx, flowClient) m.shutdownWg.Add(2)
go m.startSender(ctx) go func() {
defer m.shutdownWg.Done()
m.receiveACKs(ctx, flowClient)
}()
go func() {
defer m.shutdownWg.Done()
m.startSender(ctx)
}()
return nil return nil
} }
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
// Close cleans up all resources // Close cleans up all resources
func (m *Manager) Close() { func (m *Manager) Close() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock()
if err := m.disableFlow(); err != nil { if err := m.disableFlow(); err != nil {
log.Warnf("failed to disable flow manager: %v", err) log.Warnf("failed to disable flow manager: %v", err)
} }
m.mux.Unlock()
m.shutdownWg.Wait()
} }
// GetLogger returns the flow logger // GetLogger returns the flow logger

View File

@@ -1,4 +1,4 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd //go:build dragonfly || freebsd || netbsd || openbsd
package networkmonitor package networkmonitor
@@ -6,21 +6,19 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) fd, err := prepareFd()
if err != nil { if err != nil {
return fmt.Errorf("open routing socket: %v", err) return fmt.Errorf("open routing socket: %v", err)
} }
defer func() { defer func() {
err := unix.Close(fd) err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) { if err != nil && !errors.Is(err, unix.EBADF) {
@@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
} }
}() }()
for { return routeCheck(ctx, fd, nexthopv4, nexthopv6)
select {
case <-ctx.Done():
return ctx.Err()
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
}
}
}
}
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
} }

View File

@@ -0,0 +1,92 @@
//go:build dragonfly || freebsd || netbsd || openbsd || darwin
package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func prepareFd() (int, error) {
return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
}
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
}
}
}
}
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
}

View File

@@ -0,0 +1,149 @@
//go:build darwin && !ios
package networkmonitor
import (
"context"
"errors"
"fmt"
"hash/fnv"
"os/exec"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// todo: refactor to not use static functions
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := prepareFd()
if err != nil {
return fmt.Errorf("open routing socket: %v", err)
}
defer func() {
if err := unix.Close(fd); err != nil {
if !errors.Is(err, unix.EBADF) {
log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}
}()
routeChanged := make(chan struct{})
go func() {
_ = routeCheck(ctx, fd, nexthopv4, nexthopv6)
close(routeChanged)
}()
wakeUp := make(chan struct{})
go func() {
wakeUpListen(ctx)
close(wakeUp)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-routeChanged:
if ctx.Err() != nil {
return ctx.Err()
}
log.Infof("route change detected")
return nil
case <-wakeUp:
if ctx.Err() != nil {
return ctx.Err()
}
log.Infof("wakeup detected")
return nil
}
}
func wakeUpListen(ctx context.Context) {
log.Infof("start to watch for system wakeups")
var (
initialHash uint32
err error
)
// Keep retrying until initial sysctl succeeds or context is canceled
for {
select {
case <-ctx.Done():
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
return
default:
initialHash, err = readSleepTimeHash()
if err != nil {
log.Errorf("failed to detect initial sleep time: %v", err)
select {
case <-ctx.Done():
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
return
case <-time.After(3 * time.Second):
continue
}
}
log.Debugf("initial wakeup hash: %d", initialHash)
break
}
break
}
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Info("context canceled, stopping wakeUpListen")
return
case <-ticker.C:
newHash, err := readSleepTimeHash()
if err != nil {
log.Errorf("failed to read sleep time hash: %v", err)
continue
}
if newHash == initialHash {
log.Tracef("no wakeup detected")
continue
}
upOut, err := exec.Command("uptime").Output()
if err != nil {
log.Errorf("failed to run uptime command: %v", err)
upOut = []byte("unknown")
}
log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut)
return
}
}
}
func readSleepTimeHash() (uint32, error) {
cmd := exec.Command("sysctl", "kern.sleeptime")
out, err := cmd.Output()
if err != nil {
return 0, fmt.Errorf("failed to run sysctl: %w", err)
}
h, err := hash(out)
if err != nil {
return 0, fmt.Errorf("failed to compute hash: %w", err)
}
return h, nil
}
func hash(data []byte) (uint32, error) {
hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher
if _, err := hasher.Write(data); err != nil {
return 0, err
}
return hasher.Sum32(), nil
}

View File

@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
event := make(chan struct{}, 1) event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6) go nw.checkChanges(ctx, event, nexthop4, nexthop6)
log.Infof("start watching for network changes")
// debounce changes // debounce changes
timer := time.NewTimer(0) timer := time.NewTimer(0)
timer.Stop() timer.Stop()

View File

@@ -19,11 +19,10 @@ type SRWatcher struct {
signalClient chNotifier signalClient chNotifier
relayManager chNotifier relayManager chNotifier
listeners map[chan struct{}]struct{} listeners map[chan struct{}]struct{}
mu sync.Mutex mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config iceConfig ice.Config
cancelIceMonitor context.CancelFunc cancelIceMonitor context.CancelFunc
} }

View File

@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
if isController(w.config) { if isController(w.config) {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else { } else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} }

View File

@@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct {
DisablePromptLogin bool DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior // LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
} }
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it

View File

@@ -2,6 +2,8 @@ package relay
import ( import (
"context" "context"
"crypto/sha256"
"errors"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@@ -15,6 +17,15 @@ import (
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
const (
DefaultCacheTTL = 20 * time.Second
probeTimeout = 6 * time.Second
)
var (
ErrCheckInProgress = errors.New("probe check is already in progress")
)
// ProbeResult holds the info about the result of a relay probe request // ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct { type ProbeResult struct {
URI string URI string
@@ -22,8 +33,164 @@ type ProbeResult struct {
Addr string Addr string
} }
type StunTurnProbe struct {
cacheResults []ProbeResult
cacheTimestamp time.Time
cacheKey string
cacheTTL time.Duration
probeInProgress bool
probeDone chan struct{}
mu sync.Mutex
}
func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe {
return &StunTurnProbe{
cacheTTL: cacheTTL,
}
}
func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if p.probeInProgress {
doneChan := p.probeDone
p.mu.Unlock()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-doneChan:
return p.getCachedResults(cacheKey, stuns, turns)
}
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
p.mu.Unlock()
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
return p.getCachedResults(cacheKey, stuns, turns)
}
// ProbeAll probes all given servers asynchronously and returns the results
func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
cacheKey := generateCacheKey(stuns, turns)
p.mu.Lock()
if results := p.checkCache(cacheKey); results != nil {
p.mu.Unlock()
return results
}
if p.probeInProgress {
p.mu.Unlock()
return createErrorResults(stuns, turns)
}
p.probeInProgress = true
probeDone := make(chan struct{})
p.probeDone = probeDone
log.Infof("started new probe for STUN, TURN servers")
go func() {
p.doProbe(ctx, stuns, turns, cacheKey)
close(probeDone)
}()
p.mu.Unlock()
timer := time.NewTimer(1300 * time.Millisecond)
defer timer.Stop()
select {
case <-ctx.Done():
log.Debugf("Context cancelled while waiting for probe results")
return createErrorResults(stuns, turns)
case <-probeDone:
// when the probe is return fast, return the results right away
return p.getCachedResults(cacheKey, stuns, turns)
case <-timer.C:
// if the probe takes longer than 1.3s, return error results to avoid blocking
return createErrorResults(stuns, turns)
}
}
func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult {
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
age := time.Since(p.cacheTimestamp)
if age < p.cacheTTL {
results := append([]ProbeResult(nil), p.cacheResults...)
log.Debugf("returning cached probe results (age: %v)", age)
return results
}
}
return nil
}
func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
p.mu.Lock()
defer p.mu.Unlock()
if p.cacheKey == cacheKey && len(p.cacheResults) > 0 {
return append([]ProbeResult(nil), p.cacheResults...)
}
return createErrorResults(stuns, turns)
}
func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) {
defer func() {
p.mu.Lock()
p.probeInProgress = false
p.mu.Unlock()
}()
results := make([]ProbeResult, len(stuns)+len(turns))
var wg sync.WaitGroup
for i, uri := range stuns {
wg.Add(1)
go func(idx int, stunURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = stunURI.String()
results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI)
}(i, uri)
}
stunOffset := len(stuns)
for i, uri := range turns {
wg.Add(1)
go func(idx int, turnURI *stun.URI) {
defer wg.Done()
probeCtx, cancel := context.WithTimeout(ctx, probeTimeout)
defer cancel()
results[idx].URI = turnURI.String()
results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI)
}(stunOffset+i, uri)
}
wg.Wait()
p.mu.Lock()
p.cacheResults = results
p.cacheTimestamp = time.Now()
p.cacheKey = cacheKey
p.mu.Unlock()
log.Debug("Stored new probe results in cache")
}
// ProbeSTUN tries binding to the given STUN uri and acquiring an address // ProbeSTUN tries binding to the given STUN uri and acquiring an address
func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() { defer func() {
if probeErr != nil { if probeErr != nil {
log.Debugf("stun probe error from %s: %s", uri, probeErr) log.Debugf("stun probe error from %s: %s", uri, probeErr)
@@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
} }
// ProbeTURN tries allocating a session from the given TURN URI // ProbeTURN tries allocating a session from the given TURN URI
func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) {
defer func() { defer func() {
if probeErr != nil { if probeErr != nil {
log.Debugf("turn probe error from %s: %s", uri, probeErr) log.Debugf("turn probe error from %s: %s", uri, probeErr)
@@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
return relayConn.LocalAddr().String(), nil return relayConn.LocalAddr().String(), nil
} }
// ProbeAll probes all given servers asynchronously and returns the results func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult {
func ProbeAll( total := len(stuns) + len(turns)
ctx context.Context, results := make([]ProbeResult, total)
fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error),
relays []*stun.URI,
) []ProbeResult {
results := make([]ProbeResult, len(relays))
var wg sync.WaitGroup allURIs := append(append([]*stun.URI{}, stuns...), turns...)
for i, uri := range relays { for i, uri := range allURIs {
ctx, cancel := context.WithTimeout(ctx, 6*time.Second) results[i] = ProbeResult{
defer cancel() URI: uri.String(),
Err: ErrCheckInProgress,
wg.Add(1) }
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
} }
wg.Wait()
return results return results
} }
func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string {
h := sha256.New()
for _, uri := range stuns {
h.Write([]byte(uri.String()))
}
for _, uri := range turns {
h.Write([]byte(uri.String()))
}
return fmt.Sprintf("%x", h.Sum(nil))
}

View File

@@ -1,6 +1,7 @@
package common package common
import ( import (
"sync/atomic"
"time" "time"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
@@ -25,4 +26,5 @@ type HandlerParams struct {
UseNewDNSRoute bool UseNewDNSRoute bool
Firewall manager.Manager Firewall manager.Manager
FakeIPManager *fakeip.Manager FakeIPManager *fakeip.Manager
ForwarderPort *atomic.Uint32
} }

View File

@@ -8,6 +8,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
@@ -18,7 +19,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/common"
@@ -55,6 +55,7 @@ type DnsInterceptor struct {
peerStore *peerstore.Store peerStore *peerstore.Store
firewall firewall.Manager firewall firewall.Manager
fakeIPManager *fakeip.Manager fakeIPManager *fakeip.Manager
forwarderPort *atomic.Uint32
} }
func New(params common.HandlerParams) *DnsInterceptor { func New(params common.HandlerParams) *DnsInterceptor {
@@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor {
firewall: params.Firewall, firewall: params.Firewall,
fakeIPManager: params.FakeIPManager, fakeIPManager: params.FakeIPManager,
interceptedDomains: make(domainMap), interceptedDomains: make(domainMap),
forwarderPort: params.ForwarderPort,
} }
} }
@@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel() defer cancel()

View File

@@ -10,6 +10,7 @@ import (
"runtime" "runtime"
"slices" "slices"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -23,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client" "github.com/netbirdio/netbird/client/internal/routemanager/client"
@@ -54,6 +56,7 @@ type Manager interface {
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
SetFirewall(firewall.Manager) error SetFirewall(firewall.Manager) error
SetDNSForwarderPort(port uint16)
Stop(stateManager *statemanager.Manager) Stop(stateManager *statemanager.Manager)
} }
@@ -78,6 +81,7 @@ type DefaultManager struct {
ctx context.Context ctx context.Context
stop context.CancelFunc stop context.CancelFunc
mux sync.Mutex mux sync.Mutex
shutdownWg sync.WaitGroup
clientNetworks map[route.HAUniqueID]*client.Watcher clientNetworks map[route.HAUniqueID]*client.Watcher
routeSelector *routeselector.RouteSelector routeSelector *routeselector.RouteSelector
serverRouter *server.Router serverRouter *server.Router
@@ -101,12 +105,13 @@ type DefaultManager struct {
disableServerRoutes bool disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler activeRoutes map[route.HAUniqueID]client.RouteHandler
fakeIPManager *fakeip.Manager fakeIPManager *fakeip.Manager
dnsForwarderPort atomic.Uint32
} }
func NewManager(config ManagerConfig) *DefaultManager { func NewManager(config ManagerConfig) *DefaultManager {
mCTX, cancel := context.WithCancel(config.Context) mCTX, cancel := context.WithCancel(config.Context)
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier) sysOps := systemops.New(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil { if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name()) nbnet.SetVPNInterfaceName(config.WGInterface.Name())
@@ -130,6 +135,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
disableServerRoutes: config.DisableServerRoutes, disableServerRoutes: config.DisableServerRoutes,
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
} }
dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort))
useNoop := netstack.IsEnabled() || config.DisableClientRoutes useNoop := netstack.IsEnabled() || config.DisableClientRoutes
dm.setupRefCounters(useNoop) dm.setupRefCounters(useNoop)
@@ -270,9 +276,15 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
return nil return nil
} }
// SetDNSForwarderPort sets the DNS forwarder port for route handlers
func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
m.dnsForwarderPort.Store(uint32(port))
}
// Stop stops the manager watchers and clean firewall rules // Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop() m.stop()
m.shutdownWg.Wait()
if m.serverRouter != nil { if m.serverRouter != nil {
m.serverRouter.CleanUp() m.serverRouter.CleanUp()
} }
@@ -345,6 +357,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
UseNewDNSRoute: m.useNewDNSRoute, UseNewDNSRoute: m.useNewDNSRoute,
Firewall: m.firewall, Firewall: m.firewall,
FakeIPManager: m.fakeIPManager, FakeIPManager: m.fakeIPManager,
ForwarderPort: &m.dnsForwarderPort,
} }
handler := client.HandlerFromRoute(params) handler := client.HandlerFromRoute(params)
if err := handler.AddRoute(m.ctx); err != nil { if err := handler.AddRoute(m.ctx); err != nil {
@@ -474,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
} }
clientNetworkWatcher := client.NewWatcher(config) clientNetworkWatcher := client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start() m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes}) clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
} }
@@ -516,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
} }
clientNetworkWatcher = client.NewWatcher(config) clientNetworkWatcher = client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start() m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
} }
update := client.RoutesUpdate{ update := client.RoutesUpdate{
UpdateSerial: updateSerial, UpdateSerial: updateSerial,

View File

@@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error {
panic("implement me") panic("implement me")
} }
// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface
func (m *MockManager) SetDNSForwarderPort(port uint16) {
}
// Stop mock implementation of Stop from Manager interface // Stop mock implementation of Stop from Manager interface
func (m *MockManager) Stop(stateManager *statemanager.Manager) { func (m *MockManager) Stop(stateManager *statemanager.Manager) {
if m.StopFunc != nil { if m.StopFunc != nil {

View File

@@ -0,0 +1,8 @@
//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd)
package systemops
// FlushMarkedRoutes is a no-op on non-BSD platforms.
func (r *SysOps) FlushMarkedRoutes() error {
return nil
}

View File

@@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
sysops := NewSysOps(nil, nil) sysOps := New(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable)
sysops.refCounter.LoadData((*ExclusionCounter)(s)) sysOps.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush() return sysOps.refCounter.Flush()
} }
func (s *ShutdownState) MarshalJSON() ([]byte, error) { func (s *ShutdownState) MarshalJSON() ([]byte, error) {

View File

@@ -83,7 +83,7 @@ type SysOps struct {
localSubnetsCacheTime time.Time localSubnetsCacheTime time.Time
} }
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{ return &SysOps{
wgInterface: wgInterface, wgInterface: wgInterface,
notifier: notifier, notifier: notifier,

View File

@@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) {
_, intf = setupDummyInterface(t) _, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf} nexthop = Nexthop{netip.Addr{}, intf}
r := NewSysOps(nil, nil) r := New(nil, nil)
var wg sync.WaitGroup var wg sync.WaitGroup
for i := 0; i < 1024; i++ { for i := 0; i < 1024; i++ {
@@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin
nexthop := Nexthop{netip.Addr{}, netIntf} nexthop := Nexthop{netip.Addr{}, netIntf}
r := NewSysOps(nil, nil) r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop) err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table") require.NoError(t, err, "Failed to add route to table")

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err) require.NoError(t, err)
@@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err) require.NoError(t, err)
@@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) {
assert.NoError(t, wgInterface.Close()) assert.NoError(t, wgInterface.Close())
}) })
r := NewSysOps(wgInterface, nil) r := New(wgInterface, nil)
advancedRouting := nbnet.AdvancedRouting() advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting) err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")

View File

@@ -7,19 +7,39 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"strconv" "strconv"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/net/route" "golang.org/x/net/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
)
var routeProtoFlag int
func init() {
switch os.Getenv(envRouteProtoFlag) {
case "2":
routeProtoFlag = unix.RTF_PROTO2
case "3":
routeProtoFlag = unix.RTF_PROTO3
default:
routeProtoFlag = unix.RTF_PROTO1
}
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
@@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout
return r.cleanupRefCounter(stateManager) return r.cleanupRefCounter(stateManager)
} }
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
func (r *SysOps) FlushMarkedRoutes() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
routeInfo, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("Skipping route flush: %v", err)
continue
}
if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() {
continue
}
nexthop := Nexthop{
IP: routeInfo.Gw,
Intf: routeInfo.Interface,
}
if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err))
continue
}
flushedCount++
log.Debugf("Flushed marked route: %s", routeInfo.Dst)
}
if flushedCount > 0 {
log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
return r.routeSocket(unix.RTM_ADD, prefix, nexthop) return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
} }
@@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func(
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
msg = &route.RouteMessage{ msg = &route.RouteMessage{
Type: action, Type: action,
Flags: unix.RTF_UP, Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Seq: r.getSeq(), Seq: r.getSeq(),
} }

View File

@@ -9,8 +9,6 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
defer rs.mu.RUnlock() defer rs.mu.RUnlock()
if rs.deselectAll { if rs.deselectAll {
log.Debugf("Route %s not selected (deselect all)", routeID)
return false return false
} }
_, deselected := rs.deselectedRoutes[routeID] _, deselected := rs.deselectedRoutes[routeID]
isSelected := !deselected isSelected := !deselected
log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected)
return isSelected return isSelected
} }

View File

@@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
data, err := os.ReadFile(m.filePath) data, err := os.ReadFile(m.filePath)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
log.Debug("state file does not exist") log.Debugf("state file %s does not exist", m.filePath)
return nil, nil // nolint:nilnil return nil, nil // nolint:nilnil
} }
return nil, fmt.Errorf("read state file: %w", err) return nil, fmt.Errorf("read state file: %w", err)

View File

@@ -0,0 +1,59 @@
package winregistry
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows/registry"
)
var (
advapi = syscall.NewLazyDLL("advapi32.dll")
regCreateKeyExW = advapi.NewProc("RegCreateKeyExW")
)
const (
// Registry key options
regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted
regOptionVolatile = 0x1 // Key is not preserved when system is rebooted
// Registry disposition values
regCreatedNewKey = 0x1
regOpenedExistingKey = 0x2
)
// CreateVolatileKey creates a volatile registry key named path under open key root.
// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed.
// The access parameter specifies the access rights for the key to be created.
//
// Volatile keys are stored in memory and are automatically deleted when the system is shut down.
// This provides automatic cleanup without requiring manual registry maintenance.
func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) {
pathPtr, err := syscall.UTF16PtrFromString(path)
if err != nil {
return 0, false, err
}
var (
handle syscall.Handle
disposition uint32
)
ret, _, _ := regCreateKeyExW.Call(
uintptr(root),
uintptr(unsafe.Pointer(pathPtr)),
0, // reserved
0, // class
uintptr(regOptionVolatile), // options - volatile key
uintptr(access), // desired access
0, // security attributes
uintptr(unsafe.Pointer(&handle)),
uintptr(unsafe.Pointer(&disposition)),
)
if ret != 0 {
return 0, false, syscall.Errno(ret)
}
return registry.Key(handle), disposition == regOpenedExistingKey, nil
}

View File

@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false) oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
if err != nil { if err != nil {
return err.Error() return err.Error()
} }

View File

@@ -17,8 +17,7 @@ type Conn struct {
ID hooks.ConnectionID ID hooks.ConnectionID
} }
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection // Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
func (c *Conn) Close() error { func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn) return closeConn(c.ID, c.Conn)
} }
@@ -29,7 +28,7 @@ type TCPConn struct {
ID hooks.ConnectionID ID hooks.ConnectionID
} }
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. // Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
func (c *TCPConn) Close() error { func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn) return closeConn(c.ID, c.TCPConn)
} }
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
// closeConn is a helper function to close connections and execute close hooks. // closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error { func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close() err := conn.Close()
cleanupConnID(id)
return err
}
// cleanupConnID executes close hooks for a connection ID.
func cleanupConnID(id hooks.ConnectionID) {
closeHooks := hooks.GetCloseHooks() closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks { for _, hook := range closeHooks {
if err := hook(id); err != nil { if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err) log.Errorf("Error executing close hook: %v", err)
} }
} }
return err
} }

View File

@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
} }
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
} }
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err) log.Errorf("failed to close connection: %v", err)
} }

View File

@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
cleanupConnID(connID)
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
} }
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
ips, err := resolver.LookupIPAddr(ctx, host) ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil { if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err) return fmt.Errorf("resolve address %s: %w", address, err)
} }
log.Debugf("Dialer resolved IPs for %s: %v", address, ips) log.Debugf("Dialer resolved IPs for %s: %v", address, ips)

View File

@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.PacketConn.WriteTo(b, addr) return c.PacketConn.WriteTo(b, addr)
} }
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. // Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
func (c *PacketConn) Close() error { func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear() defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn) return closeConn(c.ID, c.PacketConn)
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
return c.UDPConn.WriteTo(b, addr) return c.UDPConn.WriteTo(b, addr)
} }
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. // Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
func (c *UDPConn) Close() error { func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear() defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn) return closeConn(c.ID, c.UDPConn)

View File

@@ -279,8 +279,10 @@ type LoginRequest struct {
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
unknownFields protoimpl.UnknownFields // hint is used to pre-fill the email/username field during SSO authentication
sizeCache protoimpl.SizeCache Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
} }
func (x *LoginRequest) Reset() { func (x *LoginRequest) Reset() {
@@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 {
return 0 return 0
} }
func (x *LoginRequest) GetHint() string {
if x != nil && x.Hint != nil {
return *x.Hint
}
return ""
}
type LoginResponse struct { type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor
const file_daemon_proto_rawDesc = "" + const file_daemon_proto_rawDesc = "" +
"\n" + "\n" +
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
"\fEmptyRequest\"\xc3\x0e\n" + "\fEmptyRequest\"\xe5\x0e\n" +
"\fLoginRequest\x12\x1a\n" + "\fLoginRequest\x12\x1a\n" +
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
@@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" +
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" +
"\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" + "\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" + "\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" + "\x0e_wireguardPortB\x17\n" +
@@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_block_inboundB\x0e\n" + "\x0e_block_inboundB\x0e\n" +
"\f_profileNameB\v\n" + "\f_profileNameB\v\n" +
"\t_usernameB\x06\n" + "\t_usernameB\x06\n" +
"\x04_mtu\"\xb5\x01\n" + "\x04_mtuB\a\n" +
"\x05_hint\"\xb5\x01\n" +
"\rLoginResponse\x12$\n" + "\rLoginResponse\x12$\n" +
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +

View File

@@ -158,6 +158,9 @@ message LoginRequest {
optional string username = 31; optional string username = 31;
optional int64 mtu = 32; optional int64 mtu = 32;
// hint is used to pre-fill the email/username field during SSO authentication
optional string hint = 33;
} }
message LoginResponse { message LoginResponse {

View File

@@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
if msg.SetupKey == "" { if msg.SetupKey == "" {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient) hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
if err != nil { if err != nil {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
return nil, err return nil, err
@@ -1057,10 +1061,7 @@ func (s *Server) Status(
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
if msg.GetFullPeerStatus { if msg.GetFullPeerStatus {
if msg.ShouldRunProbes { s.runProbes(msg.ShouldRunProbes)
s.runProbes()
}
fullStatus := s.statusRecorder.GetFullStatus() fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory() pbFullStatus.Events = s.statusRecorder.GetEventHistory()
@@ -1070,7 +1071,7 @@ func (s *Server) Status(
return &statusResponse, nil return &statusResponse, nil
} }
func (s *Server) runProbes() { func (s *Server) runProbes(waitForProbeResult bool) {
if s.connectClient == nil { if s.connectClient == nil {
return return
} }
@@ -1081,7 +1082,7 @@ func (s *Server) runProbes() {
} }
if time.Since(s.lastProbe) > probeThreshold { if time.Since(s.lastProbe) > probeThreshold {
if engine.RunHealthProbes() { if engine.RunHealthProbes(waitForProbeResult) {
s.lastProbe = time.Now() s.lastProbe = time.Now()
} }
} }

View File

@@ -10,7 +10,9 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
) )
@@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error {
merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err))
} }
// clean up any remaining routes independently of the state file
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
probeRelay "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
@@ -205,15 +206,18 @@ func mapPeers(
localICEEndpoint := "" localICEEndpoint := ""
remoteICEEndpoint := "" remoteICEEndpoint := ""
relayServerAddress := "" relayServerAddress := ""
connType := "P2P" connType := "-"
lastHandshake := time.Time{} lastHandshake := time.Time{}
transferReceived := int64(0) transferReceived := int64(0)
transferSent := int64(0) transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if pbPeerState.Relayed { if isPeerConnected {
connType = "Relayed" connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
} }
if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) { if skipDetailByFilters(pbPeerState, pbPeerState.ConnStatus, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter, connType) {
@@ -337,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
for _, relay := range overview.Relays.Details { for _, relay := range overview.Relays.Details {
available := "Available" available := "Available"
reason := "" reason := ""
if !relay.Available { if !relay.Available {
available = "Unavailable" if relay.Error == probeRelay.ErrCheckInProgress.Error() {
reason = fmt.Sprintf(", reason: %s", relay.Error) available = "Checking..."
} else {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
} }
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
} }
} else { } else {

View File

@@ -31,7 +31,6 @@ import (
"fyne.io/systray" "fyne.io/systray"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
@@ -297,6 +296,8 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem
logFile string logFile string
wLoginURL fyne.Window wLoginURL fyne.Window
connectCancel context.CancelFunc
} }
type menuHandler struct { type menuHandler struct {
@@ -593,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
} }
} }
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) return nil, fmt.Errorf("get daemon client: %w", err)
return nil, err
} }
activeProf, err := s.profileManager.GetActiveProfile() activeProf, err := s.profileManager.GetActiveProfile()
if err != nil { if err != nil {
log.Errorf("get active profile: %v", err) return nil, fmt.Errorf("get active profile: %w", err)
return nil, err
} }
currUser, err := user.Current() currUser, err := user.Current()
@@ -611,84 +610,80 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
return nil, fmt.Errorf("get current user: %w", err) return nil, fmt.Errorf("get current user: %w", err)
} }
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name, ProfileName: &activeProf.Name,
Username: &currUser.Username, Username: &currUser.Username,
}) }
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
if err != nil { if err != nil {
log.Errorf("login to management URL with: %v", err) log.Debugf("failed to get profile state for login hint: %v", err)
return nil, err } else if profileState.Email != "" {
loginReq.Hint = &profileState.Email
}
loginResp, err := conn.Login(ctx, loginReq)
if err != nil {
return nil, fmt.Errorf("login to management: %w", err)
} }
if loginResp.NeedsSSOLogin && openURL { if loginResp.NeedsSSOLogin && openURL {
err = s.handleSSOLogin(loginResp, conn) if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil {
if err != nil { return nil, fmt.Errorf("SSO login: %w", err)
log.Errorf("handle SSO login failed: %v", err)
return nil, err
} }
} }
return loginResp, nil return loginResp, nil
} }
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
err := open.Run(loginResp.VerificationURIComplete) if err := openURL(loginResp.VerificationURIComplete); err != nil {
if err != nil { return fmt.Errorf("open browser: %w", err)
log.Errorf("opening the verification uri in the browser failed: %v", err)
return err
} }
resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil { if err != nil {
log.Errorf("waiting sso login failed with: %v", err) return fmt.Errorf("wait for SSO login: %w", err)
return err
} }
if resp.Email != "" { if resp.Email != "" {
err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
Email: resp.Email, Email: resp.Email,
}) }); err != nil {
if err != nil { log.Debugf("failed to set profile state: %v", err)
log.Warnf("failed to set profile state: %v", err)
} else { } else {
s.mProfile.refresh() s.mProfile.refresh()
} }
} }
return nil return nil
} }
func (s *serviceClient) menuUpClick() error { func (s *serviceClient) menuUpClick(ctx context.Context) error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
systray.SetTemplateIcon(iconErrorMacOS, s.icError) systray.SetTemplateIcon(iconErrorMacOS, s.icError)
log.Errorf("get client: %v", err) return fmt.Errorf("get daemon client: %w", err)
return err
} }
_, err = s.login(true) _, err = s.login(ctx, true)
if err != nil { if err != nil {
log.Errorf("login failed with: %v", err) return fmt.Errorf("login: %w", err)
return err
} }
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) return fmt.Errorf("get status: %w", err)
return err
} }
if status.Status == string(internal.StatusConnected) { if status.Status == string(internal.StatusConnected) {
log.Warnf("already connected")
return nil return nil
} }
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
log.Errorf("up service: %v", err) return fmt.Errorf("start connection: %w", err)
return err
} }
return nil return nil
@@ -698,24 +693,20 @@ func (s *serviceClient) menuDownClick() error {
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) return fmt.Errorf("get daemon client: %w", err)
return err
} }
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) return fmt.Errorf("get status: %w", err)
return err
} }
if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) {
log.Warnf("already down")
return nil return nil
} }
if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil { if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
log.Errorf("down service: %v", err) return fmt.Errorf("stop connection: %w", err)
return err
} }
return nil return nil
@@ -851,6 +842,7 @@ func (s *serviceClient) onTrayReady() {
newProfileMenuArgs := &newProfileMenuArgs{ newProfileMenuArgs := &newProfileMenuArgs{
ctx: s.ctx, ctx: s.ctx,
serviceClient: s,
profileManager: s.profileManager, profileManager: s.profileManager,
eventHandler: s.eventHandler, eventHandler: s.eventHandler,
profileMenuItem: profileMenuItem, profileMenuItem: profileMenuItem,
@@ -1382,7 +1374,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return return
} }
resp, err := s.login(false) resp, err := s.login(ctx, false)
if err != nil { if err != nil {
log.Errorf("failed to fetch login URL: %v", err) log.Errorf("failed to fetch login URL: %v", err)
return return
@@ -1402,7 +1394,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return return
} }
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
if err != nil { if err != nil {
log.Errorf("Waiting sso login failed with: %v", err) log.Errorf("Waiting sso login failed with: %v", err)
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
@@ -1410,7 +1402,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
} }
label.SetText("Re-authentication successful.\nReconnecting") label.SetText("Re-authentication successful.\nReconnecting")
status, err := conn.Status(s.ctx, &proto.StatusRequest{}) status, err := conn.Status(ctx, &proto.StatusRequest{})
if err != nil { if err != nil {
log.Errorf("get service status: %v", err) log.Errorf("get service status: %v", err)
return return
@@ -1423,7 +1415,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
return return
} }
_, err = conn.Up(s.ctx, &proto.UpRequest{}) _, err = conn.Up(ctx, &proto.UpRequest{})
if err != nil { if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
log.Errorf("Reconnecting failed with: %v", err) log.Errorf("Reconnecting failed with: %v", err)
@@ -1487,6 +1479,10 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
} }
func openURL(url string) error { func openURL(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
var err error var err error
switch runtime.GOOS { switch runtime.GOOS {
case "windows": case "windows":

View File

@@ -18,6 +18,7 @@ import (
"github.com/skratchdot/open-golang/open" "github.com/skratchdot/open-golang/open"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
uptypes "github.com/netbirdio/netbird/upload-server/types" uptypes "github.com/netbirdio/netbird/upload-server/types"
@@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData(
return "", err return "", err
} }
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
log.Warnf("Failed to get post-up status: %v", err) log.Warnf("Failed to get post-up status: %v", err)
@@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string var postUpStatusOutput string
if postUpStatus != nil { if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string var preDownStatusOutput string
if preDownStatus != nil { if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
return nil, fmt.Errorf("get client: %v", err) return nil, fmt.Errorf("get client: %v", err)
} }
pm := profilemanager.NewProfileManager()
var profName string
if activeProf, err := pm.GetActiveProfile(); err == nil {
profName = activeProf.Name
}
statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
log.Warnf("failed to get status for debug bundle: %v", err) log.Warnf("failed to get status for debug bundle: %v", err)
@@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string var statusOutput string
if statusResp != nil { if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview) statusOutput = nbstatus.ParseToFullDetailSummary(overview)
} }

Some files were not shown because too many files have changed in this diff Show More