Compare commits

...

36 Commits

Author SHA1 Message Date
Viktor Liu
417fa6e833 Change default interface name from wt0 to nb0 2025-07-21 17:58:56 +02:00
Bethuel Mmbaga
a7af15c4fc [management] Fix group resource count mismatch in policy (#4182) 2025-07-21 15:26:06 +03:00
Viktor Liu
d6ed9c037e [client] Fix bind exclusion routes (#4154) 2025-07-21 12:13:21 +02:00
Ali Amer
40fdeda838 [client] add new filter-by-connection-type flag (#4010)
introduces a new flag --filter-by-connection-type to the status command.
It allows users to filter peers by connection type (P2P or Relayed) in both JSON and detailed views.

Input validation is added in parseFilters() to ensure proper usage, and --detail is auto-enabled if no output format is specified (consistent with other filters).
2025-07-21 11:55:17 +02:00
Zoltan Papp
f6e9d755e4 [client, relay] The openConn function no longer blocks the relayAddress function call (#4180)
The openConn function no longer blocks the relayAddress function call in manager layer
2025-07-21 09:46:53 +02:00
Maycon Santos
08fd460867 [management] Add validate flow response (#4172)
This PR adds a validate flow response feature to the management server by integrating an IntegratedValidator component. The main purpose is to enable validation of PKCE authorization flows through an integrated validator interface.

- Adds a new ValidateFlowResponse method to the IntegratedValidator interface
- Integrates the validator into the management server to validate PKCE authorization flows
- Updates dependency version for management-integrations
2025-07-18 12:18:52 +02:00
Pascal Fischer
4f74509d55 [management] fix index creation if exist on mysql (#4150) 2025-07-16 15:07:31 +02:00
Maycon Santos
58185ced16 [misc] add forum post and update sign pipeline (#4155)
use old git-town version
2025-07-16 14:10:28 +02:00
Pedro Maia Costa
e67f44f47c [client] fix test (#4156) 2025-07-16 12:09:38 +02:00
Zoltan Papp
b524f486e2 [client] Fix/nil relayed address (#4153)
Fix nil pointer in Relay conn address

Meanwhile, we create a relayed net.Conn struct instance, it is possible to set the relayedURL to nil.

panic: value method github.com/netbirdio/netbird/relay/client.RelayAddr.String called using nil *RelayAddr pointer

Fix relayed URL variable protection
Protect the channel closing
2025-07-16 00:00:18 +02:00
Zoltan Papp
0dab03252c [client, relay-server] Feature/relay notification (#4083)
- Clients now subscribe to peer status changes.
- The server manages and maintains these subscriptions.
- Replaced raw string peer IDs with a custom peer ID type for better type safety and clarity.
2025-07-15 10:43:42 +02:00
iisteev
e49bcc343d [client] Avoid parsing NB_NETSTACK_SKIP_PROXY if empty (#4145)
Signed-off-by: iisteev <isteevan.shetoo@is-info.fr>
2025-07-13 15:42:48 +02:00
Zoltan Papp
3e6eede152 [client] Fix elapsed time calculation when machine is in sleep mode (#4140) 2025-07-12 11:10:45 +02:00
Vlad
a76c8eafb4 [management] sync calls to UpdateAccountPeers from BufferUpdateAccountPeers (#4137)
---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
2025-07-11 12:37:14 +03:00
Pedro Maia Costa
2b9f331980 always suffix ephemeral peer name (#4138) 2025-07-11 10:29:10 +01:00
Viktor Liu
a7ea881900 [client] Add rotated logs flag for debug bundle generation (#4100) 2025-07-10 16:13:53 +02:00
Vlad
8632dd15f1 [management] added cleanupWindow for collecting several ephemeral peers to delete (#4130)
---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Pedro Costa <550684+pnmcosta@users.noreply.github.com>
2025-07-10 15:21:01 +02:00
Zoltan Papp
e3b40ba694 Update cli description of lazy connection (#4133) 2025-07-10 15:00:58 +02:00
Zoltan Papp
e59d75d56a Nil check in iface configurer (#4132) 2025-07-10 14:24:20 +02:00
Zoltan Papp
408f423adc [client] Disable pidfd check on Android 11 and below (#4127)
Disable pidfd check on Android 11 and below

On Android 11 (SDK <= 30) and earlier, pidfd-related system calls
are blocked by seccomp policies, causing SIGSYS crashes.

This change overrides `checkPidfdOnce` to return an error on
affected versions, preventing the use of unsupported pidfd features.
2025-07-09 22:16:08 +02:00
Misha Bragin
f17dd3619c [misc] update image in README.md (#4122) 2025-07-09 15:49:09 +02:00
Bethuel Mmbaga
969f1ed59a [management] Remove deleted user peers from groups on user deletion (#4121)
Refactors peer deletion to centralize group cleanup logic, ensuring deleted peers are consistently removed from all groups in one place.

- Removed redundant group removal code from DefaultAccountManager.DeletePeer
- Added group removal logic inside deletePeers to handle both single and multiple peer deletions
2025-07-09 10:14:10 +03:00
M. Essam
768ba24fda [management,rest] Add name/ip filters to peer management rest client (#4112) 2025-07-08 18:08:13 +02:00
Zoltan Papp
8942c40fde [client] Fix nil pointer exception in lazy connection (#4109)
Remove unused variable
2025-07-06 15:13:14 +02:00
Zoltan Papp
fbb1b55beb [client] refactor lazy detection (#4050)
This PR introduces a new inactivity package responsible for monitoring peer activity and notifying when peers become inactive.
Introduces a new Signal message type to close the peer connection after the idle timeout is reached.
Periodically checks the last activity of registered peers via a Bind interface.
Notifies via a channel when peers exceed a configurable inactivity threshold.
Default settings
DefaultInactivityThreshold is set to 15 minutes, with a minimum allowed threshold of 1 minute.

Limitations
This inactivity check does not support kernel WireGuard integration. In kernel–user space communication, the user space side will always be responsible for closing the connection.
2025-07-04 19:52:27 +02:00
Viktor Liu
77ec32dd6f [client] Implement dns routes for Android (#3989) 2025-07-04 16:43:11 +02:00
Bethuel Mmbaga
8c09a55057 [management] Log user id on account mismatch (#4101) 2025-07-04 10:51:58 +03:00
Pedro Maia Costa
f603ddf35e management: fix store get account peers without lock (#4092) 2025-07-04 08:44:08 +01:00
Krzysztof Nazarewski (kdn)
996b8c600c [management] replace invalid user with a clear error message about mismatched logins (#4097) 2025-07-03 16:36:36 +02:00
Maycon Santos
c4ed11d447 [client] Avoid logging setup keys on error message (#3962) 2025-07-03 16:22:18 +02:00
Viktor Liu
9afbecb7ac [client] Use unique sequence numbers for bsd routes (#4081)
updates the route manager on Unix to use a unique, incrementing sequence number for each route message instead of a fixed value.

Replace the static Seq: 1 with a call to r.getSeq()
Add an atomic seq field and the getSeq method in SysOps
2025-07-03 09:02:53 +02:00
Maycon Santos
2c81cf2c1e [management] Add account onboarding (#4084)
This PR introduces a new onboarding feature to handle such flows in the dashboard by defining an AccountOnboarding model, persisting it in the store, exposing CRUD operations in the manager and HTTP handlers, and updating API schemas and tests accordingly.

Add AccountOnboarding struct and embed it in Account
Extend Store and DefaultAccountManager with onboarding methods and SQL migrations
Update HTTP handlers, API types, OpenAPI spec, and add end-to-end tests
2025-07-03 09:01:32 +02:00
Pascal Fischer
551cb4e467 [management] expect specific error types on registration with setup key (#4094) 2025-07-02 20:04:28 +02:00
Maycon Santos
57961afe95 [doc] Add forum link (#4093)
* Add forum link

* Add forum link
2025-07-02 18:40:07 +02:00
Pascal Fischer
22678bce7f [management] add uniqueness constraint for peer ip and label and optimize generation (#4042) 2025-07-02 18:13:10 +02:00
Maycon Santos
6c633497bc [management] fix network update test for delete policy (#4086)
when adding a peer we calculate the network map an account using backpressure functions and some updates might arrive around the time we are deleting a policy.

This change ensures we wait enough time for the updates from add peer to be sent and read before continuing with the test logic
2025-07-02 12:25:31 +02:00
197 changed files with 6791 additions and 2436 deletions

View File

@@ -16,6 +16,6 @@ jobs:
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: git-town/action@v1 - uses: git-town/action@v1.2.1
with: with:
skip-single-stacks: true skip-single-stacks: true

View File

@@ -43,7 +43,7 @@ jobs:
- name: gomobile init - name: gomobile init
run: gomobile init run: gomobile init
- name: build android netbird lib - name: build android netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-checklinkname=0 -X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620 ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.20" SIGN_PIPE_VER: "v0.0.21"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -231,3 +231,17 @@ jobs:
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'
post_on_forum:
runs-on: ubuntu-latest
continue-on-error: true
needs: [trigger_signer]
steps:
- uses: Codixer/discourse-topic-github-release-action@v2.0.1
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases

View File

@@ -14,6 +14,9 @@
<br> <br>
<a href="https://docs.netbird.io/slack-url"> <a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<a href="https://forum.netbird.io">
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
</a> </a>
<br> <br>
<a href="https://gurubase.io/g/netbird"> <a href="https://gurubase.io/g/netbird">
@@ -29,13 +32,13 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
<br/> <br/>
</strong> </strong>
<br> <br>
<a href="https://github.com/netbirdio/kubernetes-operator"> <a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
New: NetBird Kubernetes Operator New: NetBird terraform provider
</a> </a>
</p> </p>
@@ -47,10 +50,9 @@
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure. **Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
### Open-Source Network Security in a Single Platform ### Open Source Network Security in a Single Platform
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video) ### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) [![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -64,7 +64,9 @@ type Client struct {
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
execWorkaround(androidSDKVersion)
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket) net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{ return &Client{
cfgFile: cfgFile, cfgFile: cfgFile,
@@ -203,8 +205,10 @@ func (c *Client) Networks() *NetworkArray {
continue continue
} }
if routes[0].IsDynamic() { r := routes[0]
continue netStr := r.Network.String()
if r.IsDynamic() {
netStr = r.Domains.SafeString()
} }
peer, err := c.recorder.GetPeer(routes[0].Peer) peer, err := c.recorder.GetPeer(routes[0].Peer)
@@ -214,7 +218,7 @@ func (c *Client) Networks() *NetworkArray {
} }
network := Network{ network := Network{
Name: string(id), Name: string(id),
Network: routes[0].Network.String(), Network: netStr,
Peer: peer.FQDN, Peer: peer.FQDN,
Status: peer.ConnStatus.String(), Status: peer.ConnStatus.String(),
} }

26
client/android/exec.go Normal file
View File

@@ -0,0 +1,26 @@
//go:build android
package android
import (
"fmt"
_ "unsafe"
)
// https://github.com/golang/go/pull/69543/commits/aad6b3b32c81795f86bc4a9e81aad94899daf520
// In Android version 11 and earlier, pidfd-related system calls
// are not allowed by the seccomp policy, which causes crashes due
// to SIGSYS signals.
//go:linkname checkPidfdOnce os.checkPidfdOnce
var checkPidfdOnce func() error
func execWorkaround(androidSDKVersion int) {
if androidSDKVersion > 30 { // above Android 11
return
}
checkPidfdOnce = func() error {
return fmt.Errorf("unsupported Android version")
}
}

View File

@@ -17,10 +17,18 @@ import (
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
) )
const errCloseConnection = "Failed to close connection: %v" const errCloseConnection = "Failed to close connection: %v"
var (
logFileCount uint32
systemInfoFlag bool
uploadBundleFlag bool
uploadBundleURLFlag string
)
var debugCmd = &cobra.Command{ var debugCmd = &cobra.Command{
Use: "debug", Use: "debug",
Short: "Debugging commands", Short: "Debugging commands",
@@ -88,12 +96,13 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag), Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag, SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
} }
if debugUploadBundle { if uploadBundleFlag {
request.UploadURL = debugUploadBundleURL request.UploadURL = uploadBundleURLFlag
} }
resp, err := client.DebugBundle(cmd.Context(), request) resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil { if err != nil {
@@ -105,7 +114,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
} }
if debugUploadBundle { if uploadBundleFlag {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
} }
@@ -223,12 +232,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: statusOutput, Status: statusOutput,
SystemInfo: debugSystemInfoFlag, SystemInfo: systemInfoFlag,
LogFileCount: logFileCount,
} }
if debugUploadBundle { if uploadBundleFlag {
request.UploadURL = debugUploadBundleURL request.UploadURL = uploadBundleURLFlag
} }
resp, err := client.DebugBundle(cmd.Context(), request) resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil { if err != nil {
@@ -255,7 +265,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason()) return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
} }
if debugUploadBundle { if uploadBundleFlag {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey()) cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
} }
@@ -297,7 +307,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil), nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, ""),
) )
} }
return statusOutputString return statusOutputString
@@ -375,3 +385,15 @@ func generateDebugBundle(config *internal.Config, recorder *peer.Status, connect
} }
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path) log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
} }
func init() {
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
}

View File

@@ -22,7 +22,6 @@ import (
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
) )
const ( const (
@@ -38,10 +37,7 @@ const (
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
enableLazyConnectionFlag = "enable-lazy-connection" enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
) )
var ( var (
@@ -75,10 +71,7 @@ var (
autoConnectDisabled bool autoConnectDisabled bool
extraIFaceBlackList []string extraIFaceBlackList []string
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
debugUploadBundle bool
debugUploadBundleURL string
lazyConnEnabled bool lazyConnEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
@@ -184,11 +177,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.") upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -26,6 +26,7 @@ var (
statusFilter string statusFilter string
ipsFilterMap map[string]struct{} ipsFilterMap map[string]struct{}
prefixNamesFilterMap map[string]struct{} prefixNamesFilterMap map[string]struct{}
connectionTypeFilter string
) )
var statusCmd = &cobra.Command{ var statusCmd = &cobra.Command{
@@ -45,6 +46,7 @@ func init() {
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -89,7 +91,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap) var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
@@ -156,6 +158,15 @@ func parseFilters() error {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
switch strings.ToLower(connectionTypeFilter) {
case "", "p2p", "relayed":
if strings.ToLower(connectionTypeFilter) != "" {
enableDetailFlagWhenFilterFlag()
}
default:
return fmt.Errorf("wrong connection-type filter, should be one of P2P|Relayed, got: %s", connectionTypeFilter)
}
return nil return nil
} }

View File

@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -104,6 +104,12 @@ type Manager struct {
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
blockRule firewall.Rule blockRule firewall.Rule
// Internal 1:1 DNAT
dnatEnabled atomic.Bool
dnatMappings map[netip.Addr]netip.Addr
dnatMutex sync.RWMutex
dnatBiMap *biDNATMap
} }
// decoder for packages // decoder for packages
@@ -189,6 +195,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger, flowLogger: flowLogger,
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
dnatMappings: make(map[netip.Addr]netip.Addr),
} }
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
@@ -519,22 +526,6 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// UpdateSet updates the rule destinations associated with the given set // UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating. // by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -581,14 +572,14 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
return nil return nil
} }
// DropOutgoing filter outgoing packets // FilterOutBound filters outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool { func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size) return m.filterOutbound(packetData, size)
} }
// DropIncoming filter incoming packets // FilterInbound filters incoming packets
func (m *Manager) DropIncoming(packetData []byte, size int) bool { func (m *Manager) FilterInbound(packetData []byte, size int) bool {
return m.dropFilter(packetData, size) return m.filterInbound(packetData, size)
} }
// UpdateLocalIPs updates the list of local IPs // UpdateLocalIPs updates the list of local IPs
@@ -596,7 +587,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface) return m.localipmanager.UpdateLocalIPs(m.wgIface)
} }
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool { func (m *Manager) filterOutbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@@ -618,8 +609,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true return true
} }
// for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, size)
m.translateOutboundDNAT(packetData, d)
return false return false
} }
@@ -723,9 +714,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
return false return false
} }
// dropFilter implements filtering logic for incoming packets. // filterInbound implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte, size int) bool { func (m *Manager) filterInbound(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@@ -747,8 +738,15 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return false return false
} }
// For all inbound traffic, first check if it matches a tracked connection. if translated := m.translateInboundReverse(packetData, d); translated {
// This must happen before any other filtering because the packets are statefully tracked. // Re-decode after translation to get original addresses
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
return true
}
srcIP, dstIP = m.extractIPs(d)
}
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false return false
} }

View File

@@ -188,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection // For stateful scenarios, establish the connection
if sc.stateful { if sc.stateful {
manager.processOutgoingHooks(outbound, 0) manager.filterOutbound(outbound, 0)
} }
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0) manager.filterInbound(inbound, 0)
} }
}) })
} }
@@ -220,7 +220,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i], outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP) uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound, 0) manager.filterOutbound(outbound, 0)
} }
// Test packet // Test packet
@@ -228,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection // First establish our test connection
manager.processOutgoingHooks(testOut, 0) manager.filterOutbound(testOut, 0)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, 0) manager.filterInbound(testIn, 0)
} }
}) })
} }
@@ -263,12 +263,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established { if sc.established {
manager.processOutgoingHooks(outbound, 0) manager.filterOutbound(outbound, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0) manager.filterInbound(inbound, 0)
} }
}) })
} }
@@ -426,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections // For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") || if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") { (strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound, 0) manager.filterOutbound(outbound, 0)
// For TCP post-handshake, simulate full handshake // For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" { if sc.state == "post_handshake" {
// SYN // SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0) manager.filterOutbound(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0) manager.filterInbound(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0) manager.filterOutbound(ack, 0)
} }
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0) manager.filterInbound(inbound, 0)
} }
}) })
} }
@@ -568,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN // Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0) manager.filterOutbound(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0) manager.filterInbound(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0) manager.filterOutbound(ack, 0)
} }
// Prepare test packets simulating bidirectional traffic // Prepare test packets simulating bidirectional traffic
@@ -599,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx], 0) manager.filterOutbound(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], 0) manager.filterInbound(inPackets[connIdx], 0)
} }
}) })
} }
@@ -700,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn, 0) manager.filterOutbound(p.syn, 0)
manager.dropFilter(p.synAck, 0) manager.filterInbound(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0) manager.filterOutbound(p.ack, 0)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request, 0) manager.filterOutbound(p.request, 0)
manager.dropFilter(p.response, 0) manager.filterInbound(p.response, 0)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient, 0) manager.filterOutbound(p.finClient, 0)
manager.dropFilter(p.ackServer, 0) manager.filterInbound(p.ackServer, 0)
manager.dropFilter(p.finServer, 0) manager.filterInbound(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0) manager.filterOutbound(p.ackClient, 0)
} }
}) })
} }
@@ -760,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ { for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0) manager.filterOutbound(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0) manager.filterInbound(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0) manager.filterOutbound(ack, 0)
} }
// Pre-generate test packets // Pre-generate test packets
@@ -790,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++ counter++
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx], 0) manager.filterOutbound(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx], 0) manager.filterInbound(inPackets[connIdx], 0)
} }
}) })
}) })
@@ -879,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn, 0) manager.filterOutbound(p.syn, 0)
manager.dropFilter(p.synAck, 0) manager.filterInbound(p.synAck, 0)
manager.processOutgoingHooks(p.ack, 0) manager.filterOutbound(p.ack, 0)
manager.processOutgoingHooks(p.request, 0) manager.filterOutbound(p.request, 0)
manager.dropFilter(p.response, 0) manager.filterInbound(p.response, 0)
manager.processOutgoingHooks(p.finClient, 0) manager.filterOutbound(p.finClient, 0)
manager.dropFilter(p.ackServer, 0) manager.filterInbound(p.ackServer, 0)
manager.dropFilter(p.finServer, 0) manager.filterInbound(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient, 0) manager.filterOutbound(p.ackClient, 0)
} }
}) })
}) })

View File

@@ -462,7 +462,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet, 0) isDropped := manager.FilterInbound(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist") require.True(t, isDropped, "Packet should be dropped when no rules exist")
}) })
@@ -509,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
}) })
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet, 0) isDropped := manager.FilterInbound(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped) require.Equal(t, tc.shouldBeBlocked, isDropped)
}) })
} }
@@ -1233,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
srcIP := netip.MustParseAddr(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed // testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed) require.Equal(t, tc.shouldPass, isAllowed)

View File

@@ -321,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes(), 0) { if m.filterInbound(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
@@ -447,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test hook gets called // Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes(), 0) result := manager.filterOutbound(buf.Bytes(), 0)
require.True(t, result) require.True(t, result)
require.True(t, hookCalled) require.True(t, hookCalled)
@@ -457,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4) err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err) require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes(), 0) result = manager.filterOutbound(buf.Bytes(), 0)
require.False(t, result) require.False(t, result)
} }
@@ -553,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Process outbound packet and verify connection tracking // Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0) drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped") require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked // Verify connection was tracked
@@ -620,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), 0) drop = manager.filterInbound(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
@@ -669,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
// Create a new outbound connection for invalid tests // Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0) drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped") require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases { for _, tc := range invalidCases {
@@ -691,7 +691,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), 0) drop = manager.filterInbound(testBuf.Bytes(), 0)
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }

View File

@@ -0,0 +1,408 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net/netip"
"github.com/google/gopacket/layers"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
func ipv4Checksum(header []byte) uint16 {
if len(header) < 20 {
return 0
}
var sum1, sum2 uint32
// Parallel processing - unroll and compute two sums simultaneously
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
// Skip checksum field at [10:12]
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
sum := sum1 + sum2
// Handle remaining bytes for headers > 20 bytes
for i := 20; i < len(header)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
}
if len(header)%2 == 1 {
sum += uint32(header[len(header)-1]) << 8
}
// Optimized carry fold - single iteration handles most cases
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
func icmpChecksum(data []byte) uint16 {
var sum1, sum2, sum3, sum4 uint32
i := 0
// Process 16 bytes at once with 4 parallel accumulators
for i <= len(data)-16 {
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
i += 16
}
sum := sum1 + sum2 + sum3 + sum4
// Handle remaining bytes
for i < len(data)-1 {
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
i += 2
}
if len(data)%2 == 1 {
sum += uint32(data[len(data)-1]) << 8
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
type biDNATMap struct {
forward map[netip.Addr]netip.Addr
reverse map[netip.Addr]netip.Addr
}
func newBiDNATMap() *biDNATMap {
return &biDNATMap{
forward: make(map[netip.Addr]netip.Addr),
reverse: make(map[netip.Addr]netip.Addr),
}
}
func (b *biDNATMap) set(original, translated netip.Addr) {
b.forward[original] = translated
b.reverse[translated] = original
}
func (b *biDNATMap) delete(original netip.Addr) {
if translated, exists := b.forward[original]; exists {
delete(b.forward, original)
delete(b.reverse, translated)
}
}
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
translated, exists := b.forward[original]
return translated, exists
}
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
original, exists := b.reverse[translated]
return original, exists
}
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
return fmt.Errorf("invalid IP addresses")
}
if m.localipmanager.IsLocalIP(translatedAddr) {
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
}
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
// Initialize both maps together if either is nil
if m.dnatMappings == nil || m.dnatBiMap == nil {
m.dnatMappings = make(map[netip.Addr]netip.Addr)
m.dnatBiMap = newBiDNATMap()
}
m.dnatMappings[originalAddr] = translatedAddr
m.dnatBiMap.set(originalAddr, translatedAddr)
if len(m.dnatMappings) == 1 {
m.dnatEnabled.Store(true)
}
return nil
}
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
m.dnatMutex.Lock()
defer m.dnatMutex.Unlock()
if _, exists := m.dnatMappings[originalAddr]; !exists {
return fmt.Errorf("mapping not found for: %s", originalAddr)
}
delete(m.dnatMappings, originalAddr)
m.dnatBiMap.delete(originalAddr)
if len(m.dnatMappings) == 0 {
m.dnatEnabled.Store(false)
}
return nil
}
// getDNATTranslation returns the translated address if a mapping exists
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return addr, false
}
m.dnatMutex.RLock()
translated, exists := m.dnatBiMap.getTranslated(addr)
m.dnatMutex.RUnlock()
return translated, exists
}
// findReverseDNATMapping finds original address for return traffic
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
if !m.dnatEnabled.Load() {
return translatedAddr, false
}
m.dnatMutex.RLock()
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
m.dnatMutex.RUnlock()
return original, exists
}
// translateOutboundDNAT applies DNAT translation to outbound packets
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
translatedIP, exists := m.getDNATTranslation(dstIP)
if !exists {
return false
}
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
m.logger.Error("Failed to rewrite packet destination: %v", err)
return false
}
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
return true
}
// translateInboundReverse applies reverse DNAT to inbound return traffic
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
if !m.dnatEnabled.Load() {
return false
}
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
return false
}
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
originalIP, exists := m.findReverseDNATMapping(srcIP)
if !exists {
return false
}
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
m.logger.Error("Failed to rewrite packet source: %v", err)
return false
}
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
return true
}
// rewritePacketDestination replaces destination IP in the packet
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldDst [4]byte
copy(oldDst[:], packetData[16:20])
newDst := newIP.As4()
copy(packetData[16:20], newDst[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
// rewritePacketSource replaces the source IP address in the packet
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
return ErrIPv4Only
}
var oldSrc [4]byte
copy(oldSrc[:], packetData[12:16])
newSrc := newIP.As4()
copy(packetData[12:16], newSrc[:])
ipHeaderLen := int(d.ip4.IHL) * 4
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
return fmt.Errorf("invalid IP header length")
}
binary.BigEndian.PutUint16(packetData[10:12], 0)
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
if len(d.decoded) > 1 {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeUDP:
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
case layers.LayerTypeICMPv4:
m.updateICMPChecksum(packetData, ipHeaderLen)
}
}
return nil
}
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
tcpStart := ipHeaderLen
if len(packetData) < tcpStart+18 {
return
}
checksumOffset := tcpStart + 16
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
udpStart := ipHeaderLen
if len(packetData) < udpStart+8 {
return
}
checksumOffset := udpStart + 6
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
if oldChecksum == 0 {
return
}
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
}
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
icmpStart := ipHeaderLen
if len(packetData) < icmpStart+8 {
return
}
icmpData := packetData[icmpStart:]
binary.BigEndian.PutUint16(icmpData[2:4], 0)
checksum := icmpChecksum(icmpData)
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
}
// incrementalUpdate performs incremental checksum update per RFC 1624
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
sum := uint32(^oldChecksum)
// Fast path for IPv4 addresses (4 bytes) - most common case
if len(oldBytes) == 4 && len(newBytes) == 4 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
} else {
// Fallback for other lengths
for i := 0; i < len(oldBytes)-1; i += 2 {
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
}
if len(oldBytes)%2 == 1 {
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
}
for i := 0; i < len(newBytes)-1; i += 2 {
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
}
if len(newBytes)%2 == 1 {
sum += uint32(newBytes[len(newBytes)-1]) << 8
}
}
sum = (sum & 0xFFFF) + (sum >> 16)
if sum > 0xFFFF {
sum++
}
return ^uint16(sum)
}
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}

View File

@@ -0,0 +1,416 @@
package uspfilter
import (
"fmt"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
)
// BenchmarkDNATTranslation measures the performance of DNAT operations
func BenchmarkDNATTranslation(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
setupDNAT bool
description string
}{
{
name: "tcp_with_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: true,
description: "TCP packet with DNAT translation enabled",
},
{
name: "tcp_without_dnat",
proto: layers.IPProtocolTCP,
setupDNAT: false,
description: "TCP packet without DNAT (baseline)",
},
{
name: "udp_with_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: true,
description: "UDP packet with DNAT translation enabled",
},
{
name: "udp_without_dnat",
proto: layers.IPProtocolUDP,
setupDNAT: false,
description: "UDP packet without DNAT (baseline)",
},
{
name: "icmp_with_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: true,
description: "ICMP packet with DNAT translation enabled",
},
{
name: "icmp_without_dnat",
proto: layers.IPProtocolICMPv4,
setupDNAT: false,
description: "ICMP packet without DNAT (baseline)",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mapping if needed
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
if sc.setupDNAT {
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Create test packets
srcIP := netip.MustParseAddr("172.16.0.1")
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
// Pre-establish connection for reverse DNAT test
if sc.setupDNAT {
manager.filterOutbound(outboundPacket, 0)
}
b.ResetTimer()
// Benchmark outbound DNAT translation
b.Run("outbound", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
// Benchmark inbound reverse DNAT translation
if sc.setupDNAT {
b.Run("inbound_reverse", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Create fresh packet each time since translation modifies it
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
manager.filterInbound(packet, 0)
}
})
}
})
}
}
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup multiple DNAT mappings
numMappings := 100
originalIPs := make([]netip.Addr, numMappings)
translatedIPs := make([]netip.Addr, numMappings)
for i := 0; i < numMappings; i++ {
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
require.NoError(b, err)
}
srcIP := netip.MustParseAddr("172.16.0.1")
// Pre-generate packets
outboundPackets := make([][]byte, numMappings)
inboundPackets := make([][]byte, numMappings)
for i := 0; i < numMappings; i++ {
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
// Establish connections
manager.filterOutbound(outboundPackets[i], 0)
}
b.ResetTimer()
b.Run("concurrent_outbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
i++
}
})
})
b.Run("concurrent_inbound", func(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
idx := i % numMappings
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
manager.filterInbound(packet, 0)
i++
}
})
})
}
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
func BenchmarkDNATScaling(b *testing.B) {
mappingCounts := []int{1, 10, 100, 1000}
for _, count := range mappingCounts {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
// Setup DNAT mappings
for i := 0; i < count; i++ {
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
}
// Test with the last mapping added (worst case for lookup)
srcIP := netip.MustParseAddr("172.16.0.1")
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
b.ResetTimer()
for i := 0; i < b.N; i++ {
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
manager.filterOutbound(packet, 0)
}
})
}
}
// generateDNATTestPacket creates a test packet for DNAT benchmarking
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
tb.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP.AsSlice(),
DstIP: dstIP.AsSlice(),
Protocol: proto,
}
var transportLayer gopacket.SerializableLayer
switch proto {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
transportLayer = udp
case layers.IPProtocolICMPv4:
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
}
transportLayer = icmp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
require.NoError(tb, err)
return buf.Bytes()
}
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
func BenchmarkChecksumUpdate(b *testing.B) {
// Create test data for checksum calculations
testData := make([]byte, 64) // Typical packet size for checksum testing
for i := range testData {
testData[i] = byte(i)
}
b.Run("ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
}
})
b.Run("icmp_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = icmpChecksum(testData)
}
})
b.Run("incremental_update", func(b *testing.B) {
oldBytes := []byte{192, 168, 1, 100}
newBytes := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
}
})
}
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
// Set logger to error level to reduce noise during benchmarking
manager.SetLogLevel(log.ErrorLevel)
defer func() {
// Restore to info level after benchmark
manager.SetLogLevel(log.InfoLevel)
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(b, err)
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
// Create fresh packet each time to isolate allocation testing
testPacket := make([]byte, len(packet))
copy(testPacket, packet)
// Parse the packet fresh each time to get a clean decoder
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err = d.parser.DecodeLayers(testPacket, &d.decoded)
assert.NoError(b, err)
manager.translateOutboundDNAT(testPacket, d)
}
}
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
func BenchmarkDirectIPExtraction(b *testing.B) {
// Create a test packet
srcIP := netip.MustParseAddr("172.16.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
b.Run("direct_byte_access", func(b *testing.B) {
for i := 0; i < b.N; i++ {
// Direct extraction from packet bytes
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
}
})
b.Run("decoder_extraction", func(b *testing.B) {
// Create decoder once for comparison
d := &decoder{decoded: []gopacket.LayerType{}}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packet, &d.decoded)
assert.NoError(b, err)
for i := 0; i < b.N; i++ {
// Extract using decoder (traditional method)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
_ = dst
}
})
}
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
func BenchmarkChecksumOptimizations(b *testing.B) {
// Create test IPv4 header (20 bytes)
header := make([]byte, 20)
for i := range header {
header[i] = byte(i)
}
// Clear checksum field
header[10] = 0
header[11] = 0
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ipv4Checksum(header)
}
})
// Test incremental checksum updates
oldIP := []byte{192, 168, 1, 100}
newIP := []byte{10, 0, 0, 100}
oldChecksum := uint16(0x1234)
b.Run("optimized_incremental_update", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
}
})
}

View File

@@ -0,0 +1,145 @@
package uspfilter
import (
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/device"
)
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
srcIP := netip.MustParseAddr("172.16.0.1")
// Add DNAT mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
testCases := []struct {
name string
protocol layers.IPProtocol
srcPort uint16
dstPort uint16
}{
{"TCP", layers.IPProtocolTCP, 12345, 80},
{"UDP", layers.IPProtocolUDP, 12345, 53},
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test outbound DNAT translation
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
originalOutbound := make([]byte, len(outboundPacket))
copy(originalOutbound, outboundPacket)
// Process outbound packet (should translate destination)
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
require.True(t, translated, "Outbound packet should be translated")
// Verify destination IP was changed
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
// Test inbound reverse DNAT translation
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
originalInbound := make([]byte, len(inboundPacket))
copy(originalInbound, inboundPacket)
// Process inbound packet (should reverse translate source)
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
require.True(t, reversed, "Inbound packet should be reverse translated")
// Verify source IP was changed back to original
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
// Test that checksums are recalculated correctly
if tc.protocol != layers.IPProtocolICMPv4 {
// For TCP/UDP, verify the transport checksum was updated
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
}
})
}
}
// parsePacket helper to create a decoder for testing
func parsePacket(t testing.TB, packetData []byte) *decoder {
t.Helper()
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
err := d.parser.DecodeLayers(packetData, &d.decoded)
require.NoError(t, err)
return d
}
// TestDNATMappingManagement tests adding/removing DNAT mappings
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
originalIP := netip.MustParseAddr("192.168.1.100")
translatedIP := netip.MustParseAddr("10.0.0.100")
// Test adding mapping
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
require.NoError(t, err)
// Verify mapping exists
result, exists := manager.getDNATTranslation(originalIP)
require.True(t, exists)
require.Equal(t, translatedIP, result)
// Test reverse lookup
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
require.True(t, exists)
require.Equal(t, originalIP, reverseResult)
// Test removing mapping
err = manager.RemoveInternalDNATMapping(originalIP)
require.NoError(t, err)
// Verify mapping no longer exists
_, exists = manager.getDNATTranslation(originalIP)
require.False(t, exists)
_, exists = manager.findReverseDNATMapping(translatedIP)
require.False(t, exists)
// Test error cases
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
require.Error(t, err, "Should reject invalid original IP")
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
require.Error(t, err, "Should reject invalid translated IP")
err = manager.RemoveInternalDNATMapping(originalIP)
require.Error(t, err, "Should error when removing non-existent mapping")
}

View File

@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state // will create or update the connection state
dropped := m.processOutgoingHooks(packetData, 0) dropped := m.filterOutbound(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else { } else {

View File

@@ -0,0 +1,96 @@
package bind
import (
"net/netip"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/monotime"
)
const (
saveFrequency = int64(5 * time.Second)
)
type PeerRecord struct {
Address netip.AddrPort
LastActivity atomic.Int64 // UnixNano timestamp
}
type ActivityRecorder struct {
mu sync.RWMutex
peers map[string]*PeerRecord // publicKey to PeerRecord map
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
}
func NewActivityRecorder() *ActivityRecorder {
return &ActivityRecorder{
peers: make(map[string]*PeerRecord),
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
}
}
// GetLastActivities returns a snapshot of peer last activity
func (r *ActivityRecorder) GetLastActivities() map[string]monotime.Time {
r.mu.RLock()
defer r.mu.RUnlock()
activities := make(map[string]monotime.Time, len(r.peers))
for key, record := range r.peers {
monoTime := record.LastActivity.Load()
activities[key] = monotime.Time(monoTime)
}
return activities
}
// UpsertAddress adds or updates the address for a publicKey
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
r.mu.Lock()
defer r.mu.Unlock()
var record *PeerRecord
record, exists := r.peers[publicKey]
if exists {
delete(r.addrToPeer, record.Address)
record.Address = address
} else {
record = &PeerRecord{
Address: address,
}
record.LastActivity.Store(int64(monotime.Now()))
r.peers[publicKey] = record
}
r.addrToPeer[address] = record
}
func (r *ActivityRecorder) Remove(publicKey string) {
r.mu.Lock()
defer r.mu.Unlock()
if record, exists := r.peers[publicKey]; exists {
delete(r.addrToPeer, record.Address)
delete(r.peers, publicKey)
}
}
// record updates LastActivity for the given address using atomic store
func (r *ActivityRecorder) record(address netip.AddrPort) {
r.mu.RLock()
record, ok := r.addrToPeer[address]
r.mu.RUnlock()
if !ok {
log.Warnf("could not find record for address %s", address)
return
}
now := int64(monotime.Now())
last := record.LastActivity.Load()
if now-last < saveFrequency {
return
}
_ = record.LastActivity.CompareAndSwap(last, now)
}

View File

@@ -0,0 +1,25 @@
package bind
import (
"net/netip"
"testing"
"time"
"github.com/netbirdio/netbird/monotime"
)
func TestActivityRecorder_GetLastActivities(t *testing.T) {
peer := "peer1"
ar := NewActivityRecorder()
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
activities := ar.GetLastActivities()
p, ok := activities[peer]
if !ok {
t.Fatalf("Expected activity for peer %s, but got none", peer)
}
if monotime.Since(p) > 5*time.Second {
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
}
}

View File

@@ -0,0 +1,15 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
func init() {
listener := nbnet.NewListener()
if listener.ListenConfig.Control != nil {
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
}
}

View File

@@ -1,12 +0,0 @@
package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
)
func init() {
// ControlFns is not thread safe and should only be modified during init.
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
}

View File

@@ -1,6 +1,7 @@
package bind package bind
import ( import (
"encoding/binary"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -15,6 +16,7 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -51,22 +53,24 @@ type ICEBind struct {
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it. closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool closed bool
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address address wgaddr.Address
activityRecorder *ActivityRecorder
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
RecvChan: make(chan RecvMessage, 1), RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet, transportNet: transportNet,
filterFn: filterFn, filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
address: address, address: address,
activityRecorder: NewActivityRecorder(),
} }
rc := receiverCreator{ rc := receiverCreator{
@@ -100,6 +104,10 @@ func (s *ICEBind) Close() error {
return s.StdNetBind.Close() return s.StdNetBind.Close()
} }
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
return s.activityRecorder
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind // GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock() s.muUDPMux.Lock()
@@ -146,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: nbnet.WrapUDPConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address, WGAddress: s.address,
@@ -199,6 +207,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
continue continue
} }
addrPort := msg.Addr.(*net.UDPAddr).AddrPort() addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
if isTransportPkg(msg.Buffers, msg.N) {
s.activityRecorder.record(addrPort)
}
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep) wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep eps[i] = ep
@@ -257,6 +270,13 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
copy(buffs[0], msg.Buffer) copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer) sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint) eps[0] = wgConn.Endpoint(msg.Endpoint)
if isTransportPkg(buffs, sizes[0]) {
if ep, ok := eps[0].(*Endpoint); ok {
c.activityRecorder.record(ep.AddrPort)
}
}
return 1, nil return 1, nil
} }
} }
@@ -272,3 +292,19 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
} }
msgsPool.Put(msgs) msgsPool.Put(msgs)
} }
func isTransportPkg(buffers [][]byte, n int) bool {
// The first buffer should contain at least 4 bytes for type
if len(buffers[0]) < 4 {
return true
}
// WireGuard packet type is a little-endian uint32 at start
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
// Check if packetType matches known WireGuard message types
if packetType == 4 && n > 32 {
return true
}
return false
}

View File

@@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
return return
} }
m.addressMapMu.Lock() var allAddresses []string
defer m.addressMapMu.Unlock()
for _, c := range removedConns { for _, c := range removedConns {
addresses := c.getAddresses() addresses := c.getAddresses()
for _, addr := range addresses { allAddresses = append(allAddresses, addresses...)
delete(m.addressMap, addr) }
}
m.addressMapMu.Lock()
for _, addr := range allAddresses {
delete(m.addressMap, addr)
}
m.addressMapMu.Unlock()
for _, addr := range allAddresses {
m.notifyAddressRemoval(addr)
} }
} }
@@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
} }
m.addressMapMu.Lock() m.addressMapMu.Lock()
defer m.addressMapMu.Unlock()
existing, ok := m.addressMap[addr] existing, ok := m.addressMap[addr]
if !ok { if !ok {
existing = []*udpMuxedConn{} existing = []*udpMuxedConn{}
} }
existing = append(existing, conn) existing = append(existing, conn)
m.addressMap[addr] = existing m.addressMap[addr] = existing
m.addressMapMu.Unlock()
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
@@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one. // muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections. // We will then forward STUN packets to each of these connections.
m.addressMapMu.Lock() m.addressMapMu.RLock()
var destinationConnList []*udpMuxedConn var destinationConnList []*udpMuxedConn
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.Unlock() m.addressMapMu.RUnlock()
var isIPv6 bool var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {

View File

@@ -0,0 +1,21 @@
//go:build !ios
package bind
import (
nbnet "github.com/netbirdio/netbird/util/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
wrapped, ok := m.params.UDPConn.(*UDPConn)
if !ok {
return
}
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
if !ok {
return
}
nbnetConn.RemoveAddress(addr)
}

View File

@@ -0,0 +1,7 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
// wrap UDP connection, process server reflexive messages // wrap UDP connection, process server reflexive messages
// before they are passed to the UDPMux connection handler (connWorker) // before they are passed to the UDPMux connection handler (connWorker)
m.params.UDPConn = &udpConn{ m.params.UDPConn = &UDPConn{
PacketConn: params.UDPConn, PacketConn: params.UDPConn,
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
@@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
// embed UDPMux
udpMuxParams := UDPMuxParams{ udpMuxParams := UDPMuxParams{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
@@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
} }
} }
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets // UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct { type UDPConn struct {
net.PacketConn net.PacketConn
mux *UniversalUDPMuxDefault mux *UniversalUDPMuxDefault
logger logging.LeveledLogger logger logging.LeveledLogger
@@ -125,7 +124,12 @@ type udpConn struct {
address wgaddr.Address address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { // GetPacketConn returns the underlying PacketConn
func (u *UDPConn) GetPacketConn() net.PacketConn {
return u.PacketConn
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil { if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
@@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
return u.handleUncachedAddress(b, addr) return u.handleUncachedAddress(b, addr)
} }
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted { if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil { if err := u.performFilterCheck(addr); err != nil {
return 0, err return 0, err
} }
return u.PacketConn.WriteTo(b, addr) return u.PacketConn.WriteTo(b, addr)
} }
func (u *udpConn) performFilterCheck(addr net.Addr) error { func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr) host, err := getHostFromAddr(addr)
if err != nil { if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err) log.Errorf("Failed to get host from address %s: %v", addr, err)

View File

@@ -11,6 +11,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
) )
var zeroKey wgtypes.Key var zeroKey wgtypes.Key
@@ -276,3 +278,7 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
} }
return stats, nil return stats, nil
} }
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
}

View File

@@ -3,4 +3,4 @@
package configurer package configurer
// WgInterfaceDefault is a default interface name of Netbird // WgInterfaceDefault is a default interface name of Netbird
const WgInterfaceDefault = "wt0" const WgInterfaceDefault = "nb0"

View File

@@ -16,6 +16,8 @@ import (
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -36,16 +38,18 @@ const (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct { type WGUSPConfigurer struct {
device *device.Device device *device.Device
deviceName string deviceName string
activityRecorder *bind.ActivityRecorder
uapiListener net.Listener uapiListener net.Listener
} }
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer { func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
wgCfg := &WGUSPConfigurer{ wgCfg := &WGUSPConfigurer{
device: device, device: device,
deviceName: deviceName, deviceName: deviceName,
activityRecorder: activityRecorder,
} }
wgCfg.startUAPI() wgCfg.startUAPI()
return wgCfg return wgCfg
@@ -87,7 +91,19 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix,
Peers: []wgtypes.PeerConfig{peer}, Peers: []wgtypes.PeerConfig{peer},
} }
return c.device.IpcSet(toWgUserspaceString(config)) if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
return ipcErr
}
if endpoint != nil {
addr, err := netip.ParseAddr(endpoint.IP.String())
if err != nil {
return fmt.Errorf("failed to parse endpoint address: %w", err)
}
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
c.activityRecorder.UpsertAddress(peerKey, addrPort)
}
return nil
} }
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
@@ -104,7 +120,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
config := wgtypes.Config{ config := wgtypes.Config{
Peers: []wgtypes.PeerConfig{peer}, Peers: []wgtypes.PeerConfig{peer},
} }
return c.device.IpcSet(toWgUserspaceString(config)) ipcErr := c.device.IpcSet(toWgUserspaceString(config))
c.activityRecorder.Remove(peerKey)
return ipcErr
} }
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
@@ -205,6 +224,10 @@ func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
return parseStatus(c.deviceName, ipcStr) return parseStatus(c.deviceName, ipcStr)
} }
func (c *WGUSPConfigurer) LastActivities() map[string]monotime.Time {
return c.activityRecorder.GetLastActivities()
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() { func (t *WGUSPConfigurer) startUAPI() {
var err error var err error

View File

@@ -79,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()

View File

@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err) return nil, fmt.Errorf("error assigning ip: %s", err)
} }
t.configurer = configurer.NewUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()

View File

@@ -9,11 +9,11 @@ import (
// PacketFilter interface for firewall abilities // PacketFilter interface for firewall abilities
type PacketFilter interface { type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations // FilterOutbound filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool FilterOutbound(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host // FilterInbound filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool FilterInbound(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
@@ -54,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) { if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...) bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...)
n-- n--
@@ -78,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs)) filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0 dropped := 0
for _, buf := range bufs { for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:], len(buf)) { if !filter.FilterInbound(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf) filteredBufs = append(filteredBufs, buf)
dropped++ dropped++
} }

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil) tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true) filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil return 1, nil
}) })
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true) filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter

View File

@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = configurer.NewUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()

View File

@@ -72,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
t.configurer = configurer.NewUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
_ = tunIface.Close() _ = tunIface.Close()

View File

@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err) return nil, fmt.Errorf("error assigning ip: %s", err)
} }
t.configurer = configurer.NewUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()

View File

@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err) return nil, fmt.Errorf("error assigning ip: %s", err)
} }
t.configurer = configurer.NewUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
err = t.configurer.ConfigureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()

View File

@@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/monotime"
) )
type WGConfigurer interface { type WGConfigurer interface {
@@ -19,4 +20,5 @@ type WGConfigurer interface {
Close() Close()
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
} }

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
) )
const ( const (
@@ -29,6 +30,11 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault WgInterfaceDefault = configurer.WgInterfaceDefault
) )
var (
// ErrIfaceNotFound is returned when the WireGuard interface is not found
ErrIfaceNotFound = fmt.Errorf("wireguard interface not found")
)
type wgProxyFactory interface { type wgProxyFactory interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Free() error Free() error
@@ -117,6 +123,9 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps) log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
@@ -126,6 +135,9 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAliv
func (w *WGIface) RemovePeer(peerKey string) error { func (w *WGIface) RemovePeer(peerKey string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
return w.configurer.RemovePeer(peerKey) return w.configurer.RemovePeer(peerKey)
@@ -135,6 +147,9 @@ func (w *WGIface) RemovePeer(peerKey string) error {
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.AddAllowedIP(peerKey, allowedIP) return w.configurer.AddAllowedIP(peerKey, allowedIP)
@@ -144,6 +159,9 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error { func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.configurer == nil {
return ErrIfaceNotFound
}
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.RemoveAllowedIP(peerKey, allowedIP) return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
@@ -214,10 +232,29 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
// GetStats returns the last handshake time, rx and tx bytes // GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) { func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.GetStats() return w.configurer.GetStats()
} }
func (w *WGIface) LastActivities() map[string]monotime.Time {
w.mu.Lock()
defer w.mu.Unlock()
if w.configurer == nil {
return nil
}
return w.configurer.LastActivities()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) { func (w *WGIface) FullStats() (*configurer.Stats, error) {
if w.configurer == nil {
return nil, ErrIfaceNotFound
}
return w.configurer.FullStats() return w.configurer.FullStats()
} }

View File

@@ -48,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
} }
// DropIncoming mocks base method. // FilterInbound mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool { func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1) ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropIncoming indicates an expected call of DropIncoming. // FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call { func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
} }
// DropOutgoing mocks base method. // FilterOutbound mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool { func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1) ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropOutgoing indicates an expected call of DropOutgoing. // FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call { func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
} }
// RemovePacketHook mocks base method. // RemovePacketHook mocks base method.

View File

@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
} }
// DropIncoming mocks base method. // FilterInbound mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0) ret := m.ctrl.Call(m, "FilterInbound", arg0)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropIncoming indicates an expected call of DropIncoming. // FilterInbound indicates an expected call of FilterInbound.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
} }
// DropOutgoing mocks base method. // FilterOutbound mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0) ret := m.ctrl.Call(m, "FilterOutbound", arg0)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropOutgoing indicates an expected call of DropOutgoing. // FilterOutbound indicates an expected call of FilterOutbound.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
} }
// SetNetwork mocks base method. // SetNetwork mocks base method.

View File

@@ -41,9 +41,12 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
} }
t.tundev = nsTunDev t.tundev = nsTunDev
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) var skipProxy bool
if err != nil { if val := os.Getenv(EnvSkipProxy); val != "" {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err) skipProxy, err = strconv.ParseBool(val)
if err != nil {
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
}
} }
if skipProxy { if skipProxy {
return nsTunDev, tunNet, nil return nsTunDev, tunNet, nil

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
type ProxyBind struct { type ProxyBind struct {
@@ -28,6 +29,17 @@ type ProxyBind struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
closeListener: listener.NewCloseListener(),
}
return p
} }
// AddTurnConn adds a new connection to the bind. // AddTurnConn adds a new connection to the bind.
@@ -54,6 +66,10 @@ func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
} }
} }
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@@ -96,6 +112,9 @@ func (p *ProxyBind) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@@ -122,6 +141,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }

View File

@@ -11,6 +11,8 @@ import (
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
@@ -26,6 +28,15 @@ type ProxyWrapper struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
closeListener: listener.NewCloseListener(),
}
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
@@ -43,6 +54,10 @@ func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr return p.wgEndpointAddr
} }
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
func (p *ProxyWrapper) Work() { func (p *ProxyWrapper) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
@@ -77,6 +92,8 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel() e.cancel()
e.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("failed to close remote conn: %w", err)
} }
@@ -117,6 +134,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
if ctx.Err() != nil { if ctx.Err() != nil {
return 0, ctx.Err() return 0, ctx.Err()
} }
p.closeListener.Notify()
if !errors.Is(err, io.EOF) { if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
} }

View File

@@ -36,9 +36,8 @@ func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort) return udpProxy.NewWGUDPProxy(w.wgPort)
} }
return &ebpf.ProxyWrapper{ return ebpf.NewProxyWrapper(w.ebpfProxy)
WgeBPFProxy: w.ebpfProxy,
}
} }
func (w *KernelFactory) Free() error { func (w *KernelFactory) Free() error {

View File

@@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
} }
func (w *USPFactory) GetProxy() Proxy { func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{ return proxyBind.NewProxyBind(w.bind)
Bind: w.bind,
}
} }
func (w *USPFactory) Free() error { func (w *USPFactory) Free() error {

View File

@@ -0,0 +1,19 @@
package listener
type CloseListener struct {
listener func()
}
func NewCloseListener() *CloseListener {
return &CloseListener{}
}
func (c *CloseListener) SetCloseListener(listener func()) {
c.listener = listener
}
func (c *CloseListener) Notify() {
if c.listener != nil {
c.listener()
}
}

View File

@@ -12,4 +12,5 @@ type Proxy interface {
Work() // Work start or resume the proxy Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error CloseConn() error
SetDisconnectListener(disconnected func())
} }

View File

@@ -98,9 +98,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
t.Errorf("failed to free ebpf proxy: %s", err) t.Errorf("failed to free ebpf proxy: %s", err)
} }
}() }()
proxyWrapper := &ebpf.ProxyWrapper{ proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct { tests = append(tests, struct {
name string name string

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
) )
// WGUDPProxy proxies // WGUDPProxy proxies
@@ -28,6 +29,8 @@ type WGUDPProxy struct {
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
isStarted bool isStarted bool
closeListener *listener.CloseListener
} }
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
@@ -35,6 +38,7 @@ func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
closeListener: listener.NewCloseListener(),
} }
return p return p
} }
@@ -67,6 +71,10 @@ func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
return endpointUdpAddr return endpointUdpAddr
} }
func (p *WGUDPProxy) SetDisconnectListener(disconnected func()) {
p.closeListener.SetCloseListener(disconnected)
}
// Work starts the proxy or resumes it if it was paused // Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() { func (p *WGUDPProxy) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
@@ -111,6 +119,8 @@ func (p *WGUDPProxy) close() error {
if p.closed { if p.closed {
return nil return nil
} }
p.closeListener.SetCloseListener(nil)
p.closed = true p.closed = true
p.cancel() p.cancel()
@@ -141,6 +151,7 @@ func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
p.closeListener.Notify()
log.Debugf("failed to read from wg interface conn: %s", err) log.Debugf("failed to read from wg interface conn: %s", err)
return return
} }

View File

@@ -39,7 +39,7 @@ const (
) )
var defaultInterfaceBlacklist = []string{ var defaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts", iface.WgInterfaceDefault, "nb", "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo", "Tailscale", "tailscale", "docker", "veth", "br-", "lo",
} }

View File

@@ -12,7 +12,6 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager" "github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -26,11 +25,11 @@ import (
// //
// The implementation is not thread-safe; it is protected by engine.syncMsgMux. // The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct { type ConnMgr struct {
peerStore *peerstore.Store peerStore *peerstore.Store
statusRecorder *peer.Status statusRecorder *peer.Status
iface lazyconn.WGIface iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher enabledLocally bool
enabledLocally bool rosenpassEnabled bool
lazyConnMgr *manager.Manager lazyConnMgr *manager.Manager
@@ -39,12 +38,12 @@ type ConnMgr struct {
lazyCtxCancel context.CancelFunc lazyCtxCancel context.CancelFunc
} }
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr { func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
e := &ConnMgr{ e := &ConnMgr{
peerStore: peerStore, peerStore: peerStore,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
iface: iface, iface: iface,
dispatcher: dispatcher, rosenpassEnabled: engineConfig.RosenpassEnabled,
} }
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() { if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true e.enabledLocally = true
@@ -64,6 +63,11 @@ func (e *ConnMgr) Start(ctx context.Context) {
return return
} }
if e.rosenpassEnabled {
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
return
}
e.initLazyManager(ctx) e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true) e.statusRecorder.UpdateLazyConnection(true)
} }
@@ -83,7 +87,12 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
return nil return nil
} }
log.Infof("lazy connection manager is enabled by management feature flag") if e.rosenpassEnabled {
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
return nil
}
log.Warnf("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx) e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true) e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager() return e.addPeersToLazyConnManager()
@@ -133,7 +142,7 @@ func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
excludedPeers = append(excludedPeers, lazyPeerCfg) excludedPeers = append(excludedPeers, lazyPeerCfg)
} }
added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers) added := e.lazyConnMgr.ExcludePeer(excludedPeers)
for _, peerID := range added { for _, peerID := range added {
var peerConn *peer.Conn var peerConn *peer.Conn
var exists bool var exists bool
@@ -175,7 +184,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
PeerConnID: conn.ConnID(), PeerConnID: conn.ConnID(),
Log: conn.Log, Log: conn.Log,
} }
excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg) excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
if err != nil { if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err) conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil { if err := conn.Open(ctx); err != nil {
@@ -201,7 +210,7 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
if !ok { if !ok {
return return
} }
defer conn.Close() defer conn.Close(false)
if !e.isStartedWithLazyMgr() { if !e.isStartedWithLazyMgr() {
return return
@@ -211,23 +220,27 @@ func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn.Log.Infof("removed peer from lazy conn manager") conn.Log.Infof("removed peer from lazy conn manager")
} }
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) { func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if !e.isStartedWithLazyMgr() { if !e.isStartedWithLazyMgr() {
return conn, true return
} }
if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found { if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(ctx); err != nil { if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err) conn.Log.Errorf("failed to open connection: %v", err)
} }
} }
return conn, true }
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
// If locally the lazy connection is disabled, we force the peer connection open.
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
if !e.isStartedWithLazyMgr() {
return
}
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
} }
func (e *ConnMgr) Close() { func (e *ConnMgr) Close() {
@@ -244,7 +257,7 @@ func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
cfg := manager.Config{ cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(), InactivityThreshold: inactivityThresholdEnv(),
} }
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher) e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx) e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
@@ -275,7 +288,7 @@ func (e *ConnMgr) addPeersToLazyConnManager() error {
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg) lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
} }
return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs) return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
} }
func (e *ConnMgr) closeManager(ctx context.Context) { func (e *ConnMgr) closeManager(ctx context.Context) {

View File

@@ -167,6 +167,7 @@ type BundleGenerator struct {
anonymize bool anonymize bool
clientStatus string clientStatus string
includeSystemInfo bool includeSystemInfo bool
logFileCount uint32
archive *zip.Writer archive *zip.Writer
} }
@@ -175,6 +176,7 @@ type BundleConfig struct {
Anonymize bool Anonymize bool
ClientStatus string ClientStatus string
IncludeSystemInfo bool IncludeSystemInfo bool
LogFileCount uint32
} }
type GeneratorDependencies struct { type GeneratorDependencies struct {
@@ -185,6 +187,12 @@ type GeneratorDependencies struct {
} }
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator { func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
// Default to 1 log file for backward compatibility when 0 is provided
logFileCount := cfg.LogFileCount
if logFileCount == 0 {
logFileCount = 1
}
return &BundleGenerator{ return &BundleGenerator{
anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()), anonymizer: anonymize.NewAnonymizer(anonymize.DefaultAddresses()),
@@ -196,6 +204,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
anonymize: cfg.Anonymize, anonymize: cfg.Anonymize,
clientStatus: cfg.ClientStatus, clientStatus: cfg.ClientStatus,
includeSystemInfo: cfg.IncludeSystemInfo, includeSystemInfo: cfg.IncludeSystemInfo,
logFileCount: logFileCount,
} }
} }
@@ -561,32 +570,8 @@ func (g *BundleGenerator) addLogfile() error {
return fmt.Errorf("add client log file to zip: %w", err) return fmt.Errorf("add client log file to zip: %w", err)
} }
// add latest rotated log file // add rotated log files based on logFileCount
pattern := filepath.Join(logDir, "client-*.log.gz") g.addRotatedLogFiles(logDir)
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
} else if len(files) > 0 {
// pick the file with the latest ModTime
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().Before(fj.ModTime())
})
latest := files[len(files)-1]
name := filepath.Base(latest)
if err := g.addSingleLogFileGz(latest, name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
stdErrLogPath := filepath.Join(logDir, errorLogFile) stdErrLogPath := filepath.Join(logDir, errorLogFile)
stdoutLogPath := filepath.Join(logDir, stdoutLogFile) stdoutLogPath := filepath.Join(logDir, stdoutLogFile)
@@ -670,6 +655,52 @@ func (g *BundleGenerator) addSingleLogFileGz(logPath, targetName string) error {
return nil return nil
} }
// addRotatedLogFiles adds rotated log files to the bundle based on logFileCount
func (g *BundleGenerator) addRotatedLogFiles(logDir string) {
if g.logFileCount == 0 {
return
}
pattern := filepath.Join(logDir, "client-*.log.gz")
files, err := filepath.Glob(pattern)
if err != nil {
log.Warnf("failed to glob rotated logs: %v", err)
return
}
if len(files) == 0 {
return
}
// sort files by modification time (newest first)
sort.Slice(files, func(i, j int) bool {
fi, err := os.Stat(files[i])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[i], err)
return false
}
fj, err := os.Stat(files[j])
if err != nil {
log.Warnf("failed to stat rotated log %s: %v", files[j], err)
return false
}
return fi.ModTime().After(fj.ModTime())
})
// include up to logFileCount rotated files
maxFiles := int(g.logFileCount)
if maxFiles > len(files) {
maxFiles = len(files)
}
for i := 0; i < maxFiles; i++ {
name := filepath.Base(files[i])
if err := g.addSingleLogFileGz(files[i], name); err != nil {
log.Warnf("failed to add rotated log %s: %v", name, err)
}
}
}
func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error { func (g *BundleGenerator) addFileToZip(reader io.Reader, filename string) error {
header := &zip.FileHeader{ header := &zip.FileHeader{
Name: filename, Name: filename,

View File

@@ -464,7 +464,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())

View File

@@ -38,7 +38,6 @@ import (
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
@@ -62,7 +61,6 @@ import (
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto" sProto "github.com/netbirdio/netbird/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -139,9 +137,6 @@ type Engine struct {
connMgr *ConnMgr connMgr *ConnMgr
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
// rpManager is a Rosenpass manager // rpManager is a Rosenpass manager
rpManager *rosenpass.Manager rpManager *rosenpass.Manager
@@ -175,8 +170,7 @@ type Engine struct {
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server sshServer nbssh.Server
statusRecorder *peer.Status statusRecorder *peer.Status
peerConnDispatcher *dispatcher.ConnectionDispatcher
firewall firewallManager.Manager firewall firewallManager.Manager
routeManager routemanager.Manager routeManager routemanager.Manager
@@ -383,7 +377,13 @@ func (e *Engine) Start() error {
} }
e.stateManager.Start() e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer() initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err)
}
dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil { if err != nil {
e.close() e.close()
return fmt.Errorf("create dns server: %w", err) return fmt.Errorf("create dns server: %w", err)
@@ -400,16 +400,13 @@ func (e *Engine) Start() error {
InitialRoutes: initialRoutes, InitialRoutes: initialRoutes,
StateManager: e.stateManager, StateManager: e.stateManager,
DNSServer: dnsServer, DNSServer: dnsServer,
DNSFeatureFlag: dnsFeatureFlag,
PeerStore: e.peerStore, PeerStore: e.peerStore,
DisableClientRoutes: e.config.DisableClientRoutes, DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes, DisableServerRoutes: e.config.DisableServerRoutes,
}) })
beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err := e.routeManager.Init(); err != nil {
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else {
e.beforePeerHook = beforePeerHook
e.afterPeerHook = afterPeerHook
} }
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
@@ -451,9 +448,7 @@ func (e *Engine) Start() error {
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
} }
e.peerConnDispatcher = dispatcher.NewConnectionDispatcher() e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface)
e.connMgr = NewConnMgr(e.config, e.statusRecorder, e.peerStore, wgIface, e.peerConnDispatcher)
e.connMgr.Start(e.ctx) e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
@@ -488,9 +483,9 @@ func (e *Engine) createFirewall() error {
} }
func (e *Engine) initFirewall() error { func (e *Engine) initFirewall() error {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { if err := e.routeManager.SetFirewall(e.firewall); err != nil {
e.close() e.close()
return fmt.Errorf("enable server router: %w", err) return fmt.Errorf("set firewall: %w", err)
} }
if e.config.BlockLANAccess { if e.config.BlockLANAccess {
@@ -1009,8 +1004,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err) log.Errorf("failed to update dns server, err: %v", err)
} }
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1021,6 +1014,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes)) log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
} }
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil { if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err) log.Errorf("failed to update routes: %v", err)
} }
@@ -1255,14 +1249,10 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
} }
if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists { if exists := e.connMgr.AddPeerConn(e.ctx, peerKey, conn); exists {
conn.Close() conn.Close(false)
return fmt.Errorf("peer already exists: %s", peerKey) return fmt.Errorf("peer already exists: %s", peerKey)
} }
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
conn.AddAfterRemovePeerHook(e.afterPeerHook)
}
return nil return nil
} }
@@ -1302,13 +1292,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
} }
serviceDependencies := peer.ServiceDependencies{ serviceDependencies := peer.ServiceDependencies{
StatusRecorder: e.statusRecorder, StatusRecorder: e.statusRecorder,
Signaler: e.signaler, Signaler: e.signaler,
IFaceDiscover: e.mobileDep.IFaceDiscover, IFaceDiscover: e.mobileDep.IFaceDiscover,
RelayManager: e.relayManager, RelayManager: e.relayManager,
SrWatcher: e.srWatcher, SrWatcher: e.srWatcher,
Semaphore: e.connSemaphore, Semaphore: e.connSemaphore,
PeerConnDispatcher: e.peerConnDispatcher,
} }
peerConn, err := peer.NewConn(config, serviceDependencies) peerConn, err := peer.NewConn(config, serviceDependencies)
if err != nil { if err != nil {
@@ -1331,11 +1320,16 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
conn, ok := e.connMgr.OnSignalMsg(e.ctx, msg.Key) conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok { if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key) return fmt.Errorf("wrongly addressed message %s", msg.Key)
} }
msgType := msg.GetBody().GetType()
if msgType != sProto.Body_GO_IDLE {
e.connMgr.ActivatePeer(e.ctx, conn)
}
switch msg.GetBody().Type { switch msg.GetBody().Type {
case sProto.Body_OFFER: case sProto.Body_OFFER:
remoteCred, err := signal.UnMarshalCredential(msg) remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1392,6 +1386,8 @@ func (e *Engine) receiveSignalEvents() {
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes()) go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
case sProto.Body_MODE: case sProto.Body_MODE:
case sProto.Body_GO_IDLE:
e.connMgr.DeactivatePeer(conn)
} }
return nil return nil
@@ -1489,7 +1485,12 @@ func (e *Engine) close() {
} }
} }
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, error) {
if runtime.GOOS != "android" {
// nolint:nilnil
return nil, nil, false, nil
}
info := system.GetInfo(e.ctx) info := system.GetInfo(e.ctx)
info.SetFlags( info.SetFlags(
e.config.RosenpassEnabled, e.config.RosenpassEnabled,
@@ -1506,11 +1507,12 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
netMap, err := e.mgmClient.GetNetworkMap(info) netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, false, err
} }
routes := toRoutes(netMap.GetRoutes()) routes := toRoutes(netMap.GetRoutes())
dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network) dnsCfg := toDNSConfig(netMap.GetDNSConfig(), e.wgInterface.Address().Network)
return routes, &dnsCfg, nil dnsFeatureFlag := toDNSFeatureFlag(netMap)
return routes, &dnsCfg, dnsFeatureFlag, nil
} }
func (e *Engine) newWgIface() (*iface.WGIface, error) { func (e *Engine) newWgIface() (*iface.WGIface, error) {
@@ -1558,18 +1560,14 @@ func (e *Engine) wgInterfaceCreate() (err error) {
return err return err
} }
func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
// due to tests where we are using a mocked version of the DNS server // due to tests where we are using a mocked version of the DNS server
if e.dnsServer != nil { if e.dnsServer != nil {
return nil, e.dnsServer, nil return e.dnsServer, nil
} }
switch runtime.GOOS { switch runtime.GOOS {
case "android": case "android":
routes, dnsConfig, err := e.readInitialSettings()
if err != nil {
return nil, nil, err
}
dnsServer := dns.NewDefaultServerPermanentUpstream( dnsServer := dns.NewDefaultServerPermanentUpstream(
e.ctx, e.ctx,
e.wgInterface, e.wgInterface,
@@ -1580,19 +1578,19 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
e.config.DisableDNS, e.config.DisableDNS,
) )
go e.mobileDep.DnsReadyListener.OnReady() go e.mobileDep.DnsReadyListener.OnReady()
return routes, dnsServer, nil return dnsServer, nil
case "ios": case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return nil, dnsServer, nil return dnsServer, nil
default: default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
return nil, dnsServer, nil return dnsServer, nil
} }
} }

View File

@@ -36,7 +36,6 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@@ -53,6 +52,7 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
@@ -97,6 +97,7 @@ type MockWGIface struct {
GetInterfaceGUIDStringFunc func() (string, error) GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy GetProxyFunc func() wgproxy.Proxy
GetNetFunc func() *netstack.Net GetNetFunc func() *netstack.Net
LastActivitiesFunc func() map[string]monotime.Time
} }
func (m *MockWGIface) FullStats() (*configurer.Stats, error) { func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
@@ -187,6 +188,13 @@ func (m *MockWGIface) GetNet() *netstack.Net {
return m.GetNetFunc() return m.GetNetFunc()
} }
func (m *MockWGIface) LastActivities() map[string]monotime.Time {
if m.LastActivitiesFunc != nil {
return m.LastActivitiesFunc()
}
return nil
}
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
_ = util.InitLog("debug", "console") _ = util.InitLog("debug", "console")
code := m.Run() code := m.Run()
@@ -392,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
StatusRecorder: engine.statusRecorder, StatusRecorder: engine.statusRecorder,
RelayManager: relayMgr, RelayManager: relayMgr,
}) })
_, _, err = engine.routeManager.Init() err = engine.routeManager.Init()
require.NoError(t, err) require.NoError(t, err)
engine.dnsServer = &dns.MockServer{ engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -404,7 +412,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface, dispatcher.NewConnectionDispatcher()) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
engine.connMgr.Start(ctx) engine.connMgr.Start(ctx)
type testCase struct { type testCase struct {
@@ -793,7 +801,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
engine.routeManager = mockRouteManager engine.routeManager = mockRouteManager
engine.dnsServer = &dns.MockServer{} engine.dnsServer = &dns.MockServer{}
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher()) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
engine.connMgr.Start(ctx) engine.connMgr.Start(ctx)
defer func() { defer func() {
@@ -991,7 +999,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
} }
engine.dnsServer = mockDNSServer engine.dnsServer = mockDNSServer
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface, dispatcher.NewConnectionDispatcher()) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, engine.wgInterface)
engine.connMgr.Start(ctx) engine.connMgr.Start(ctx)
defer func() { defer func() {
@@ -1385,7 +1393,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
if runtime.GOOS == "darwin" { if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i) ifaceName = fmt.Sprintf("utun1%d", i)
} else { } else {
ifaceName = fmt.Sprintf("wt%d", i) ifaceName = fmt.Sprintf("nb%d", i)
} }
wgPort := 33100 + i wgPort := 33100 + i
@@ -1473,6 +1481,10 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()). GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
@@ -1482,7 +1494,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
) )
type wgIfaceBase interface { type wgIfaceBase interface {
@@ -38,4 +39,5 @@ type wgIfaceBase interface {
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net GetNet() *netstack.Net
FullStats() (*configurer.Stats, error) FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
} }

View File

@@ -13,7 +13,7 @@ import (
// Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking // Listener it is not a thread safe implementation, do not call Close before ReadPackets. It will cause blocking
type Listener struct { type Listener struct {
wgIface lazyconn.WGIface wgIface WgInterface
peerCfg lazyconn.PeerConfig peerCfg lazyconn.PeerConfig
conn *net.UDPConn conn *net.UDPConn
endpoint *net.UDPAddr endpoint *net.UDPAddr
@@ -22,7 +22,7 @@ type Listener struct {
isClosed atomic.Bool // use to avoid error log when closing the listener isClosed atomic.Bool // use to avoid error log when closing the listener
} }
func NewListener(wgIface lazyconn.WGIface, cfg lazyconn.PeerConfig) (*Listener, error) { func NewListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*Listener, error) {
d := &Listener{ d := &Listener{
wgIface: wgIface, wgIface: wgIface,
peerCfg: cfg, peerCfg: cfg,
@@ -48,7 +48,7 @@ func (d *Listener) ReadPackets() {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1)) n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
if err != nil { if err != nil {
if d.isClosed.Load() { if d.isClosed.Load() {
d.peerCfg.Log.Debugf("exit from activity listener") d.peerCfg.Log.Infof("exit from activity listener")
} else { } else {
d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err) d.peerCfg.Log.Errorf("failed to read from activity listener: %s", err)
} }
@@ -59,9 +59,11 @@ func (d *Listener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr) d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue continue
} }
d.peerCfg.Log.Infof("activity detected")
break break
} }
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.removeEndpoint(); err != nil { if err := d.removeEndpoint(); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err) d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
} }
@@ -71,7 +73,7 @@ func (d *Listener) ReadPackets() {
} }
func (d *Listener) Close() { func (d *Listener) Close() {
d.peerCfg.Log.Infof("closing listener: %s", d.conn.LocalAddr().String()) d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
d.isClosed.Store(true) d.isClosed.Store(true)
if err := d.conn.Close(); err != nil { if err := d.conn.Close(); err != nil {
@@ -81,7 +83,6 @@ func (d *Listener) Close() {
} }
func (d *Listener) removeEndpoint() error { func (d *Listener) removeEndpoint() error {
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
return d.wgIface.RemovePeer(d.peerCfg.PublicKey) return d.wgIface.RemovePeer(d.peerCfg.PublicKey)
} }

View File

@@ -1,18 +1,27 @@
package activity package activity
import ( import (
"net"
"net/netip"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
) )
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
}
type Manager struct { type Manager struct {
OnActivityChan chan peerid.ConnID OnActivityChan chan peerid.ConnID
wgIface lazyconn.WGIface wgIface WgInterface
peers map[peerid.ConnID]*Listener peers map[peerid.ConnID]*Listener
done chan struct{} done chan struct{}
@@ -20,7 +29,7 @@ type Manager struct {
mu sync.Mutex mu sync.Mutex
} }
func NewManager(wgIface lazyconn.WGIface) *Manager { func NewManager(wgIface WgInterface) *Manager {
m := &Manager{ m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1), OnActivityChan: make(chan peerid.ConnID, 1),
wgIface: wgIface, wgIface: wgIface,

View File

@@ -1,75 +0,0 @@
package inactivity
import (
"context"
"time"
peer "github.com/netbirdio/netbird/client/internal/peer/id"
)
const (
DefaultInactivityThreshold = 60 * time.Minute // idle after 1 hour inactivity
MinimumInactivityThreshold = 3 * time.Minute
)
type Monitor struct {
id peer.ConnID
timer *time.Timer
cancel context.CancelFunc
inactivityThreshold time.Duration
}
func NewInactivityMonitor(peerID peer.ConnID, threshold time.Duration) *Monitor {
i := &Monitor{
id: peerID,
timer: time.NewTimer(0),
inactivityThreshold: threshold,
}
i.timer.Stop()
return i
}
func (i *Monitor) Start(ctx context.Context, timeoutChan chan peer.ConnID) {
i.timer.Reset(i.inactivityThreshold)
defer i.timer.Stop()
ctx, i.cancel = context.WithCancel(ctx)
defer func() {
defer i.cancel()
select {
case <-i.timer.C:
default:
}
}()
select {
case <-i.timer.C:
select {
case timeoutChan <- i.id:
case <-ctx.Done():
return
}
case <-ctx.Done():
return
}
}
func (i *Monitor) Stop() {
if i.cancel == nil {
return
}
i.cancel()
}
func (i *Monitor) PauseTimer() {
i.timer.Stop()
}
func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold)
}
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
i.Stop()
go i.Start(ctx, timeoutChan)
}

View File

@@ -1,156 +0,0 @@
package inactivity
import (
"context"
"testing"
"time"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
type MocPeer struct {
}
func (m *MocPeer) ConnID() peerid.ConnID {
return peerid.ConnID(m)
}
func TestInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestReuseInactivityMonitor(t *testing.T) {
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), time.Second*2)
timeoutChan := make(chan peerid.ConnID)
for i := 2; i > 0; i-- {
exitChan := make(chan struct{})
testTimeoutCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
go func() {
defer close(exitChan)
im.Start(testTimeoutCtx, timeoutChan)
}()
select {
case <-timeoutChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
select {
case <-exitChan:
case <-testTimeoutCtx.Done():
t.Fatal("timeout")
}
testTimeoutCancel()
}
}
func TestStopInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*5)
defer testTimeoutCancel()
p := &MocPeer{}
im := NewInactivityMonitor(p.ConnID(), DefaultInactivityThreshold)
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(tCtx, timeoutChan)
}()
go func() {
time.Sleep(3 * time.Second)
im.Stop()
}()
select {
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-exitChan:
case <-tCtx.Done():
t.Fatal("timeout")
}
}
func TestPauseInactivityMonitor(t *testing.T) {
tCtx, testTimeoutCancel := context.WithTimeout(context.Background(), time.Second*10)
defer testTimeoutCancel()
p := &MocPeer{}
trashHold := time.Second * 3
im := NewInactivityMonitor(p.ConnID(), trashHold)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
timeoutChan := make(chan peerid.ConnID)
exitChan := make(chan struct{})
go func() {
defer close(exitChan)
im.Start(ctx, timeoutChan)
}()
time.Sleep(1 * time.Second) // grant time to start the monitor
im.PauseTimer()
// check to do not receive timeout
thresholdCtx, thresholdCancel := context.WithTimeout(context.Background(), trashHold+time.Second)
defer thresholdCancel()
select {
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
t.Fatal("unexpected timeout")
case <-thresholdCtx.Done():
// test ok
case <-tCtx.Done():
t.Fatal("test timed out")
}
// test reset timer
im.ResetTimer()
select {
case <-tCtx.Done():
t.Fatal("test timed out")
case <-exitChan:
t.Fatal("unexpected exit")
case <-timeoutChan:
// expected timeout
}
}

View File

@@ -0,0 +1,155 @@
package inactivity
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/monotime"
)
const (
checkInterval = 1 * time.Minute
DefaultInactivityThreshold = 15 * time.Minute
MinimumInactivityThreshold = 1 * time.Minute
)
type WgInterface interface {
LastActivities() map[string]monotime.Time
}
type Manager struct {
inactivePeersChan chan map[string]struct{}
iface WgInterface
interestedPeers map[string]*lazyconn.PeerConfig
inactivityThreshold time.Duration
}
func NewManager(iface WgInterface, configuredThreshold *time.Duration) *Manager {
inactivityThreshold, err := validateInactivityThreshold(configuredThreshold)
if err != nil {
inactivityThreshold = DefaultInactivityThreshold
log.Warnf("invalid inactivity threshold configured: %v, using default: %v", err, DefaultInactivityThreshold)
}
log.Infof("inactivity threshold configured: %v", inactivityThreshold)
return &Manager{
inactivePeersChan: make(chan map[string]struct{}, 1),
iface: iface,
interestedPeers: make(map[string]*lazyconn.PeerConfig),
inactivityThreshold: inactivityThreshold,
}
}
func (m *Manager) InactivePeersChan() chan map[string]struct{} {
if m == nil {
// return a nil channel that blocks forever
return nil
}
return m.inactivePeersChan
}
func (m *Manager) AddPeer(peerCfg *lazyconn.PeerConfig) {
if m == nil {
return
}
if _, exists := m.interestedPeers[peerCfg.PublicKey]; exists {
return
}
peerCfg.Log.Infof("adding peer to inactivity manager")
m.interestedPeers[peerCfg.PublicKey] = peerCfg
}
func (m *Manager) RemovePeer(peer string) {
if m == nil {
return
}
pi, ok := m.interestedPeers[peer]
if !ok {
return
}
pi.Log.Debugf("remove peer from inactivity manager")
delete(m.interestedPeers, peer)
}
func (m *Manager) Start(ctx context.Context) {
if m == nil {
return
}
ticker := newTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C():
idlePeers, err := m.checkStats()
if err != nil {
log.Errorf("error checking stats: %v", err)
return
}
if len(idlePeers) == 0 {
continue
}
m.notifyInactivePeers(ctx, idlePeers)
}
}
}
func (m *Manager) notifyInactivePeers(ctx context.Context, inactivePeers map[string]struct{}) {
select {
case m.inactivePeersChan <- inactivePeers:
case <-ctx.Done():
return
default:
return
}
}
func (m *Manager) checkStats() (map[string]struct{}, error) {
lastActivities := m.iface.LastActivities()
idlePeers := make(map[string]struct{})
checkTime := time.Now()
for peerID, peerCfg := range m.interestedPeers {
lastActive, ok := lastActivities[peerID]
if !ok {
// when peer is in connecting state
peerCfg.Log.Warnf("peer not found in wg stats")
continue
}
since := monotime.Since(lastActive)
if since > m.inactivityThreshold {
peerCfg.Log.Infof("peer is inactive since time: %s", checkTime.Add(-since).String())
idlePeers[peerID] = struct{}{}
}
}
return idlePeers, nil
}
func validateInactivityThreshold(configuredThreshold *time.Duration) (time.Duration, error) {
if configuredThreshold == nil {
return DefaultInactivityThreshold, nil
}
if *configuredThreshold < MinimumInactivityThreshold {
return 0, fmt.Errorf("configured inactivity threshold %v is too low, using %v", *configuredThreshold, MinimumInactivityThreshold)
}
return *configuredThreshold, nil
}

View File

@@ -0,0 +1,114 @@
package inactivity
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/monotime"
)
type mockWgInterface struct {
lastActivities map[string]monotime.Time
}
func (m *mockWgInterface) LastActivities() map[string]monotime.Time {
return m.lastActivities
}
func TestPeerTriggersInactivity(t *testing.T) {
peerID := "peer1"
wgMock := &mockWgInterface{
lastActivities: map[string]monotime.Time{
peerID: monotime.Time(int64(monotime.Now()) - int64(20*time.Minute)),
},
}
fakeTick := make(chan time.Time, 1)
newTicker = func(d time.Duration) Ticker {
return &fakeTickerMock{CChan: fakeTick}
}
peerLog := log.WithField("peer", peerID)
peerCfg := &lazyconn.PeerConfig{
PublicKey: peerID,
Log: peerLog,
}
manager := NewManager(wgMock, nil)
manager.AddPeer(peerCfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the manager in a goroutine
go manager.Start(ctx)
// Send a tick to simulate time passage
fakeTick <- time.Now()
// Check if peer appears on inactivePeersChan
select {
case inactivePeers := <-manager.inactivePeersChan:
assert.Contains(t, inactivePeers, peerID, "expected peer to be marked inactive")
case <-time.After(1 * time.Second):
t.Fatal("expected inactivity event, but none received")
}
}
func TestPeerTriggersActivity(t *testing.T) {
peerID := "peer1"
wgMock := &mockWgInterface{
lastActivities: map[string]monotime.Time{
peerID: monotime.Time(int64(monotime.Now()) - int64(5*time.Minute)),
},
}
fakeTick := make(chan time.Time, 1)
newTicker = func(d time.Duration) Ticker {
return &fakeTickerMock{CChan: fakeTick}
}
peerLog := log.WithField("peer", peerID)
peerCfg := &lazyconn.PeerConfig{
PublicKey: peerID,
Log: peerLog,
}
manager := NewManager(wgMock, nil)
manager.AddPeer(peerCfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Start the manager in a goroutine
go manager.Start(ctx)
// Send a tick to simulate time passage
fakeTick <- time.Now()
// Check if peer appears on inactivePeersChan
select {
case <-manager.inactivePeersChan:
t.Fatal("expected inactive peer to be marked inactive")
case <-time.After(1 * time.Second):
// No inactivity event should be received
}
}
// fakeTickerMock implements Ticker interface for testing
type fakeTickerMock struct {
CChan chan time.Time
}
func (f *fakeTickerMock) C() <-chan time.Time {
return f.CChan
}
func (f *fakeTickerMock) Stop() {}

View File

@@ -0,0 +1,24 @@
package inactivity
import "time"
var newTicker = func(d time.Duration) Ticker {
return &realTicker{t: time.NewTicker(d)}
}
type Ticker interface {
C() <-chan time.Time
Stop()
}
type realTicker struct {
t *time.Ticker
}
func (r *realTicker) C() <-chan time.Time {
return r.t.C
}
func (r *realTicker) Stop() {
r.t.Stop()
}

View File

@@ -11,7 +11,6 @@ import (
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity" "github.com/netbirdio/netbird/client/internal/lazyconn/activity"
"github.com/netbirdio/netbird/client/internal/lazyconn/inactivity" "github.com/netbirdio/netbird/client/internal/lazyconn/inactivity"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -43,60 +42,46 @@ type Config struct {
type Manager struct { type Manager struct {
engineCtx context.Context engineCtx context.Context
peerStore *peerstore.Store peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration inactivityThreshold time.Duration
connStateListener *dispatcher.ConnectionListener
managedPeers map[string]*lazyconn.PeerConfig managedPeers map[string]*lazyconn.PeerConfig
managedPeersByConnID map[peerid.ConnID]*managedPeer managedPeersByConnID map[peerid.ConnID]*managedPeer
excludes map[string]lazyconn.PeerConfig excludes map[string]lazyconn.PeerConfig
managedPeersMu sync.Mutex managedPeersMu sync.Mutex
activityManager *activity.Manager activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor inactivityManager *inactivity.Manager
// Route HA group management // Route HA group management
// If any peer in the same HA group is active, all peers in that group should prevent going idle
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex routesMu sync.RWMutex
onInactive chan peerid.ConnID
} }
// NewManager creates a new lazy connection manager // NewManager creates a new lazy connection manager
// engineCtx is the context for creating peer Connection // engineCtx is the context for creating peer Connection
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager { func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface) *Manager {
log.Infof("setup lazy connection service") log.Infof("setup lazy connection service")
m := &Manager{ m := &Manager{
engineCtx: engineCtx, engineCtx: engineCtx,
peerStore: peerStore, peerStore: peerStore,
connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold, inactivityThreshold: inactivity.DefaultInactivityThreshold,
managedPeers: make(map[string]*lazyconn.PeerConfig), managedPeers: make(map[string]*lazyconn.PeerConfig),
managedPeersByConnID: make(map[peerid.ConnID]*managedPeer), managedPeersByConnID: make(map[peerid.ConnID]*managedPeer),
excludes: make(map[string]lazyconn.PeerConfig), excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface), activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
peerToHAGroups: make(map[string][]route.HAUniqueID), peerToHAGroups: make(map[string][]route.HAUniqueID),
haGroupToPeers: make(map[route.HAUniqueID][]string), haGroupToPeers: make(map[route.HAUniqueID][]string),
onInactive: make(chan peerid.ConnID),
} }
if config.InactivityThreshold != nil { if wgIface.IsUserspaceBind() {
if *config.InactivityThreshold >= inactivity.MinimumInactivityThreshold { m.inactivityManager = inactivity.NewManager(wgIface, config.InactivityThreshold)
m.inactivityThreshold = *config.InactivityThreshold } else {
} else { log.Warnf("inactivity manager not supported for kernel mode, wait for remote peer to close the connection")
log.Warnf("inactivity threshold is too low, using %v", m.inactivityThreshold)
}
} }
m.connStateListener = &dispatcher.ConnectionListener{
OnConnected: m.onPeerConnected,
OnDisconnected: m.onPeerDisconnected,
}
connStateDispatcher.AddListener(m.connStateListener)
return m return m
} }
@@ -131,24 +116,28 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
} }
} }
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes", len(m.haGroupToPeers), len(m.peerToHAGroups))
len(m.haGroupToPeers), len(m.peerToHAGroups))
} }
// Start starts the manager and listens for peer activity and inactivity events // Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) { func (m *Manager) Start(ctx context.Context) {
defer m.close() defer m.close()
if m.inactivityManager != nil {
go m.inactivityManager.Start(ctx)
}
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case peerConnID := <-m.activityManager.OnActivityChan: case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID) m.onPeerActivity(peerConnID)
case peerConnID := <-m.onInactive: case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(ctx, peerConnID) m.onPeerInactivityTimedOut(peerIDs)
} }
} }
} }
// ExcludePeer marks peers for a permanent connection // ExcludePeer marks peers for a permanent connection
@@ -156,7 +145,7 @@ func (m *Manager) Start(ctx context.Context) {
// Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In // Adds them back to the managed list and start the inactivity listener if they are removed from the exclude list. In
// this case, we suppose that the connection status is connected or connecting. // this case, we suppose that the connection status is connected or connecting.
// If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function // If the peer is not exists yet in the managed list then the responsibility is the upper layer to call the AddPeer function
func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerConfig) []string { func (m *Manager) ExcludePeer(peerConfigs []lazyconn.PeerConfig) []string {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -187,7 +176,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
peerCfg.Log.Infof("peer removed from lazy connection exclude list") peerCfg.Log.Infof("peer removed from lazy connection exclude list")
if err := m.addActivePeer(ctx, peerCfg); err != nil { if err := m.addActivePeer(&peerCfg); err != nil {
log.Errorf("failed to add peer to lazy connection manager: %s", err) log.Errorf("failed to add peer to lazy connection manager: %s", err)
continue continue
} }
@@ -197,7 +186,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
return added return added
} }
func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) { func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -217,9 +206,6 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo
return false, err return false, err
} }
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold)
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{ m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg, peerCfg: &peerCfg,
@@ -229,7 +215,7 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo
// Check if this peer should be activated because its HA group peers are active // Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok { if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group) peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(ctx, peerCfg) m.activateNewPeerInActiveGroup(peerCfg)
} }
return false, nil return false, nil
@@ -237,7 +223,7 @@ func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (boo
// AddActivePeers adds a list of peers to the lazy connection manager // AddActivePeers adds a list of peers to the lazy connection manager
// suppose these peers was in connected or in connecting states // suppose these peers was in connected or in connecting states
func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerConfig) error { func (m *Manager) AddActivePeers(peerCfg []lazyconn.PeerConfig) error {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -247,7 +233,7 @@ func (m *Manager) AddActivePeers(ctx context.Context, peerCfg []lazyconn.PeerCon
continue continue
} }
if err := m.addActivePeer(ctx, cfg); err != nil { if err := m.addActivePeer(&cfg); err != nil {
cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err) cfg.Log.Errorf("failed to add peer to lazy connection manager: %v", err)
return err return err
} }
@@ -264,7 +250,7 @@ func (m *Manager) RemovePeer(peerID string) {
// ActivatePeer activates a peer connection when a signal message is received // ActivatePeer activates a peer connection when a signal message is received
// Also activates all peers in the same HA groups as this peer // Also activates all peers in the same HA groups as this peer
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) { func (m *Manager) ActivatePeer(peerID string) (found bool) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
cfg, mp := m.getPeerForActivation(peerID) cfg, mp := m.getPeerForActivation(peerID)
@@ -272,15 +258,43 @@ func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool)
return false return false
} }
if !m.activateSinglePeer(ctx, cfg, mp) { cfg.Log.Infof("activate peer from inactive state by remote signal message")
if !m.activateSinglePeer(cfg, mp) {
return false return false
} }
m.activateHAGroupPeers(ctx, peerID) m.activateHAGroupPeers(cfg)
return true return true
} }
func (m *Manager) DeactivatePeer(peerID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
mp.expectedWatcher = watcherActivity
m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return
}
}
// getPeerForActivation checks if a peer can be activated and returns the necessary structs // getPeerForActivation checks if a peer can be activated and returns the necessary structs
// Returns nil values if the peer should be skipped // Returns nil values if the peer should be skipped
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) { func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
@@ -302,41 +316,36 @@ func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *ma
return cfg, mp return cfg, mp
} }
// activateSinglePeer activates a single peer (internal method) // activateSinglePeer activates a single peer
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool { // return true if the peer was activated, false if it was already active
mp.expectedWatcher = watcherInactivity func (m *Manager) activateSinglePeer(cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
if mp.expectedWatcher == watcherInactivity {
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
im, ok := m.inactivityMonitors[cfg.PeerConnID]
if !ok {
cfg.Log.Errorf("inactivity monitor not found for peer")
return false return false
} }
cfg.Log.Infof("starting inactivity monitor") mp.expectedWatcher = watcherInactivity
go im.Start(ctx, m.onInactive) m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
m.inactivityManager.AddPeer(cfg)
return true return true
} }
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to // activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) { func (m *Manager) activateHAGroupPeers(triggeredPeerCfg *lazyconn.PeerConfig) {
var peersToActivate []string var peersToActivate []string
m.routesMu.RLock() m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggerPeerID] haGroups := m.peerToHAGroups[triggeredPeerCfg.PublicKey]
if len(haGroups) == 0 { if len(haGroups) == 0 {
m.routesMu.RUnlock() m.routesMu.RUnlock()
log.Debugf("peer %s is not part of any HA groups", triggerPeerID) triggeredPeerCfg.Log.Debugf("peer is not part of any HA groups")
return return
} }
for _, haGroup := range haGroups { for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup] peers := m.haGroupToPeers[haGroup]
for _, peerID := range peers { for _, peerID := range peers {
if peerID != triggerPeerID { if peerID != triggeredPeerCfg.PublicKey {
peersToActivate = append(peersToActivate, peerID) peersToActivate = append(peersToActivate, peerID)
} }
} }
@@ -350,16 +359,16 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string
continue continue
} }
if m.activateSinglePeer(ctx, cfg, mp) { if m.activateSinglePeer(cfg, mp) {
activatedCount++ activatedCount++
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID) cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggeredPeerCfg.PublicKey)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
} }
} }
if activatedCount > 0 { if activatedCount > 0 {
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)", log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
activatedCount, triggerPeerID, haGroups) activatedCount, triggeredPeerCfg.PublicKey, haGroups)
} }
} }
@@ -394,13 +403,13 @@ func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool)
} }
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group // activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) { func (m *Manager) activateNewPeerInActiveGroup(peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID] mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok { if !ok {
return return
} }
if !m.activateSinglePeer(ctx, &peerCfg, mp) { if !m.activateSinglePeer(&peerCfg, mp) {
return return
} }
@@ -408,23 +417,19 @@ func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazy
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
} }
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error { func (m *Manager) addActivePeer(peerCfg *lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok { if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed") peerCfg.Log.Warnf("peer already managed")
return nil return nil
} }
im := inactivity.NewInactivityMonitor(peerCfg.PeerConnID, m.inactivityThreshold) m.managedPeers[peerCfg.PublicKey] = peerCfg
m.inactivityMonitors[peerCfg.PeerConnID] = im
m.managedPeers[peerCfg.PublicKey] = &peerCfg
m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{ m.managedPeersByConnID[peerCfg.PeerConnID] = &managedPeer{
peerCfg: &peerCfg, peerCfg: peerCfg,
expectedWatcher: watcherInactivity, expectedWatcher: watcherInactivity,
} }
peerCfg.Log.Infof("starting inactivity monitor on peer that has been removed from exclude list") m.inactivityManager.AddPeer(peerCfg)
go im.Start(ctx, m.onInactive)
return nil return nil
} }
@@ -436,12 +441,7 @@ func (m *Manager) removePeer(peerID string) {
cfg.Log.Infof("removing lazy peer") cfg.Log.Infof("removing lazy peer")
if im, ok := m.inactivityMonitors[cfg.PeerConnID]; ok { m.inactivityManager.RemovePeer(cfg.PublicKey)
im.Stop()
delete(m.inactivityMonitors, cfg.PeerConnID)
cfg.Log.Debugf("inactivity monitor stopped")
}
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID) m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
delete(m.managedPeers, peerID) delete(m.managedPeers, peerID)
delete(m.managedPeersByConnID, cfg.PeerConnID) delete(m.managedPeersByConnID, cfg.PeerConnID)
@@ -451,12 +451,8 @@ func (m *Manager) close() {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close() m.activityManager.Close()
for _, iw := range m.inactivityMonitors {
iw.Stop()
}
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig) m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer) m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
@@ -470,7 +466,7 @@ func (m *Manager) close() {
} }
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements // shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(peerID string) bool { func (m *Manager) shouldDeferIdleForHA(inactivePeers map[string]struct{}, peerID string) bool {
m.routesMu.RLock() m.routesMu.RLock()
defer m.routesMu.RUnlock() defer m.routesMu.RUnlock()
@@ -480,38 +476,45 @@ func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
} }
for _, haGroup := range haGroups { for _, haGroup := range haGroups {
groupPeers := m.haGroupToPeers[haGroup] if active := m.checkHaGroupActivity(haGroup, peerID, inactivePeers); active {
return true
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// Other member is still connected, defer idle
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
return true
}
} }
} }
return false return false
} }
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) { func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string, inactivePeers map[string]struct{}) bool {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// If any peer in the group is active, do defer idle
if _, isInactive := inactivePeers[groupPeerID]; !isInactive {
return true
}
}
return false
}
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -528,100 +531,56 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
mp.peerCfg.Log.Infof("detected peer activity") mp.peerCfg.Log.Infof("detected peer activity")
if !m.activateSinglePeer(ctx, mp.peerCfg, mp) { if !m.activateSinglePeer(mp.peerCfg, mp) {
return return
} }
m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey) m.activateHAGroupPeers(mp.peerCfg)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
} }
func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) { func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID] for peerID := range peerIDs {
if !ok { peerCfg, ok := m.managedPeers[peerID]
log.Errorf("peer not found by id: %v", peerConnID) if !ok {
return log.Errorf("peer not found by peerId: %v", peerID)
} continue
if mp.expectedWatcher != watcherInactivity {
mp.peerCfg.Log.Warnf("ignore inactivity event")
return
}
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
iw, ok := m.inactivityMonitors[peerConnID]
if ok {
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
iw.ResetMonitor(ctx, m.onInactive)
} else {
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
} }
return
}
mp.peerCfg.Log.Infof("connection timed out") mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerCfg.PeerConnID)
continue
}
// this is blocking operation, potentially can be optimized if mp.expectedWatcher != watcherInactivity {
m.peerStore.PeerConnClose(mp.peerCfg.PublicKey) mp.peerCfg.Log.Warnf("ignore inactivity event")
continue
}
mp.peerCfg.Log.Infof("start activity monitor") if m.shouldDeferIdleForHA(peerIDs, mp.peerCfg.PublicKey) {
mp.peerCfg.Log.Infof("defer inactivity due to active HA group peers")
continue
}
mp.expectedWatcher = watcherActivity mp.peerCfg.Log.Infof("connection timed out")
// just in case free up // this is blocking operation, potentially can be optimized
m.inactivityMonitors[peerConnID].PauseTimer() m.peerStore.PeerConnIdle(mp.peerCfg.PublicKey)
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil { mp.expectedWatcher = watcherActivity
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
return m.inactivityManager.RemovePeer(mp.peerCfg.PublicKey)
mp.peerCfg.Log.Infof("start activity monitor")
if err := m.activityManager.MonitorPeerActivity(*mp.peerCfg); err != nil {
mp.peerCfg.Log.Errorf("failed to create activity monitor: %v", err)
continue
}
} }
} }
func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
return
}
mp.peerCfg.Log.Infof("peer connected, pausing inactivity monitor while connection is not disconnected")
iw.PauseTimer()
}
func (m *Manager) onPeerDisconnected(peerConnID peerid.ConnID) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
if !ok {
return
}
if mp.expectedWatcher != watcherInactivity {
return
}
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok {
return
}
mp.peerCfg.Log.Infof("reset inactivity monitor timer")
iw.ResetTimer()
}

View File

@@ -6,9 +6,13 @@ import (
"time" "time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/monotime"
) )
type WGIface interface { type WGIface interface {
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
LastActivities() map[string]monotime.Time
} }

View File

@@ -148,7 +148,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
) )
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil { if err != nil {
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String()) log.Errorf("failed registering peer %v", err)
return nil, err return nil, err
} }

View File

@@ -19,7 +19,7 @@ type mockIFaceMapper struct {
} }
func (m *mockIFaceMapper) Name() string { func (m *mockIFaceMapper) Name() string {
return "wt0" return "nb0"
} }
func (m *mockIFaceMapper) Address() wgaddr.Address { func (m *mockIFaceMapper) Address() wgaddr.Address {

View File

@@ -26,7 +26,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
relayClient "github.com/netbirdio/netbird/relay/client" relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
) )
@@ -106,10 +105,6 @@ type Conn struct {
workerRelay *WorkerRelay workerRelay *WorkerRelay
wgWatcherWg sync.WaitGroup wgWatcherWg sync.WaitGroup
connIDRelay nbnet.ConnectionID
connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
rosenpassRemoteKey []byte rosenpassRemoteKey []byte
@@ -117,10 +112,9 @@ type Conn struct {
wgProxyRelay wgproxy.Proxy wgProxyRelay wgproxy.Proxy
handshaker *Handshaker handshaker *Handshaker
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup semaphore *semaphoregroup.SemaphoreGroup
peerConnDispatcher *dispatcher.ConnectionDispatcher wg sync.WaitGroup
wg sync.WaitGroup
// debug purpose // debug purpose
dumpState *stateDump dumpState *stateDump
@@ -136,18 +130,17 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
connLog := log.WithField("peer", config.Key) connLog := log.WithField("peer", config.Key)
var conn = &Conn{ var conn = &Conn{
Log: connLog, Log: connLog,
config: config, config: config,
statusRecorder: services.StatusRecorder, statusRecorder: services.StatusRecorder,
signaler: services.Signaler, signaler: services.Signaler,
iFaceDiscover: services.IFaceDiscover, iFaceDiscover: services.IFaceDiscover,
relayManager: services.RelayManager, relayManager: services.RelayManager,
srWatcher: services.SrWatcher, srWatcher: services.SrWatcher,
semaphore: services.Semaphore, semaphore: services.Semaphore,
peerConnDispatcher: services.PeerConnDispatcher, statusRelay: worker.NewAtomicStatus(),
statusRelay: worker.NewAtomicStatus(), statusICE: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(), dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
} }
return conn, nil return conn, nil
@@ -169,7 +162,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx) conn.ctx, conn.ctxCancel = context.WithCancel(engineCtx)
conn.workerRelay = NewWorkerRelay(conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
@@ -226,7 +219,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
} }
// Close closes this peer Conn issuing a close event to the Conn closeCh // Close closes this peer Conn issuing a close event to the Conn closeCh
func (conn *Conn) Close() { func (conn *Conn) Close(signalToRemote bool) {
conn.mu.Lock() conn.mu.Lock()
defer conn.wgWatcherWg.Wait() defer conn.wgWatcherWg.Wait()
defer conn.mu.Unlock() defer conn.mu.Unlock()
@@ -236,6 +229,12 @@ func (conn *Conn) Close() {
return return
} }
if signalToRemote {
if err := conn.signaler.SignalIdle(conn.config.Key); err != nil {
conn.Log.Errorf("failed to signal idle state to peer: %v", err)
}
}
conn.Log.Infof("close peer connection") conn.Log.Infof("close peer connection")
conn.ctxCancel() conn.ctxCancel()
@@ -263,8 +262,6 @@ func (conn *Conn) Close() {
conn.Log.Errorf("failed to remove wg endpoint: %v", err) conn.Log.Errorf("failed to remove wg endpoint: %v", err)
} }
conn.freeUpConnID()
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
conn.onDisconnected(conn.config.WgConfig.RemoteKey) conn.onDisconnected(conn.config.WgConfig.RemoteKey)
} }
@@ -289,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
conn.workerICE.OnRemoteCandidate(candidate, haRoutes) conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
} }
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) { func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
conn.onConnected = handler conn.onConnected = handler
@@ -383,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
ep = directEp ep = directEp
} }
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
conn.workerRelay.DisableWgWatcher() conn.workerRelay.DisableWgWatcher()
// todo consider to run conn.wgWatcherWg.Wait() here // todo consider to run conn.wgWatcherWg.Wait() here
@@ -404,15 +390,10 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
} }
wgConfigWorkaround() wgConfigWorkaround()
oldState := conn.currentConnPriority
conn.currentConnPriority = priority conn.currentConnPriority = priority
conn.statusICE.SetConnected() conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo) conn.updateIceState(iceConnInfo)
conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
if oldState == conntype.None {
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
}
} }
func (conn *Conn) onICEStateDisconnected() { func (conn *Conn) onICEStateDisconnected() {
@@ -450,7 +431,6 @@ func (conn *Conn) onICEStateDisconnected() {
} else { } else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
} }
changed := conn.statusICE.Get() != worker.StatusDisconnected changed := conn.statusICE.Get() != worker.StatusDisconnected
@@ -491,6 +471,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.Log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
wgProxy.SetDisconnectListener(conn.onRelayDisconnected)
conn.dumpState.NewLocalProxy() conn.dumpState.NewLocalProxy()
conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.Log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
@@ -503,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
conn.Log.Errorf("Before add peer hook failed: %v", err)
}
wgProxy.Work() wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil { if err := wgProxy.CloseConn(); err != nil {
@@ -530,7 +508,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
conn.Log.Infof("start to communicate with peer via relay") conn.Log.Infof("start to communicate with peer via relay")
conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
conn.peerConnDispatcher.NotifyConnected(conn.ConnID())
} }
func (conn *Conn) onRelayDisconnected() { func (conn *Conn) onRelayDisconnected() {
@@ -545,11 +522,7 @@ func (conn *Conn) onRelayDisconnected() {
if conn.currentConnPriority == conntype.Relay { if conn.currentConnPriority == conntype.Relay {
conn.Log.Debugf("clean up WireGuard config") conn.Log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
conn.currentConnPriority = conntype.None conn.currentConnPriority = conntype.None
conn.peerConnDispatcher.NotifyDisconnected(conn.ConnID())
} }
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
@@ -712,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
return true return true
} }
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
conn.connIDICE = nbnet.GenerateConnID()
for _, hook := range conn.beforeAddPeerHooks {
if err := hook(conn.connIDICE, ip); err != nil {
return err
}
}
return nil
}
func (conn *Conn) freeUpConnID() {
if conn.connIDRelay != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDRelay); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDRelay = ""
}
if conn.connIDICE != "" {
for _, hook := range conn.afterRemovePeerHooks {
if err := hook(conn.connIDICE); err != nil {
conn.Log.Errorf("After remove peer hook failed: %v", err)
}
}
conn.connIDICE = ""
}
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
conn.Log.Debugf("setup proxied WireGuard connection") conn.Log.Debugf("setup proxied WireGuard connection")
udpAddr := &net.UDPAddr{ udpAddr := &net.UDPAddr{

View File

@@ -68,3 +68,13 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
return nil return nil
} }
func (s *Signaler) SignalIdle(remoteKey string) error {
return s.signal.Send(&sProto.Message{
Key: s.wgPrivateKey.PublicKey().String(),
RemoteKey: remoteKey,
Body: &sProto.Body{
Type: sProto.Body_GO_IDLE,
},
})
}

View File

@@ -19,6 +19,7 @@ type RelayConnInfo struct {
} }
type WorkerRelay struct { type WorkerRelay struct {
peerCtx context.Context
log *log.Entry log *log.Entry
isController bool isController bool
config ConnConfig config ConnConfig
@@ -33,8 +34,9 @@ type WorkerRelay struct {
wgWatcher *WGWatcher wgWatcher *WGWatcher
} }
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay { func NewWorkerRelay(ctx context.Context, log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
peerCtx: ctx,
log: log, log: log,
isController: ctrl, isController: ctrl,
config: config, config: config,
@@ -62,7 +64,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
if err != nil { if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) { if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection") w.log.Debugf("handled offer by reusing existing relay connection")

View File

@@ -95,6 +95,17 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
} }
func (s *Store) PeerConnIdle(pubKey string) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
p.Close(true)
}
func (s *Store) PeerConnClose(pubKey string) { func (s *Store) PeerConnClose(pubKey string) {
s.peerConnsMu.RLock() s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock() defer s.peerConnsMu.RUnlock()
@@ -103,7 +114,7 @@ func (s *Store) PeerConnClose(pubKey string) {
if !ok { if !ok {
return return
} }
p.Close() p.Close(false)
} }
func (s *Store) PeersPubKey() []string { func (s *Store) PeersPubKey() []string {

View File

@@ -10,11 +10,10 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -553,41 +552,16 @@ func (w *Watcher) Stop() {
w.currentChosenStatus = nil w.currentChosenStatus = nil
} }
func HandlerFromRoute( func HandlerFromRoute(params common.HandlerParams) RouteHandler {
rt *route.Route, switch handlerType(params.Route, params.UseNewDNSRoute) {
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) RouteHandler {
switch handlerType(rt, useNewDNSRoute) {
case handlerTypeDnsInterceptor: case handlerTypeDnsInterceptor:
return dnsinterceptor.New( return dnsinterceptor.New(params)
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
wgInterface,
peerStore,
)
case handlerTypeDynamic: case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(wgInterface) dns := nbdns.NewServiceViaMemory(params.WgInterface)
return dynamic.NewRoute( dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
rt, return dynamic.NewRoute(params, dnsAddr)
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
default: default:
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) return static.NewRoute(params)
} }
} }

View File

@@ -7,12 +7,12 @@ import (
"time" "time"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
func TestGetBestrouteFromStatuses(t *testing.T) { func TestGetBestrouteFromStatuses(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
statuses map[route.ID]routerPeerStatus statuses map[route.ID]routerPeerStatus
@@ -811,9 +811,12 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute = tc.existingRoutes[tc.currentRoute] currentRoute = tc.existingRoutes[tc.currentRoute]
} }
params := common.HandlerParams{
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
}
// create new clientNetwork // create new clientNetwork
client := &Watcher{ client := &Watcher{
handler: static.NewRoute(&route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, nil, nil), handler: static.NewRoute(params),
routes: tc.existingRoutes, routes: tc.existingRoutes,
currentChosen: currentRoute, currentChosen: currentRoute,
} }

View File

@@ -0,0 +1,28 @@
package common
import (
"time"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
)
type HandlerParams struct {
Route *route.Route
RouteRefCounter *refcounter.RouteRefCounter
AllowedIPsRefCounter *refcounter.AllowedIPsRefCounter
DnsRouterInterval time.Duration
StatusRecorder *peer.Status
WgInterface iface.WGIface
DnsServer dns.Server
PeerStore *peerstore.Store
UseNewDNSRoute bool
Firewall manager.Manager
FakeIPManager *fakeip.Manager
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime"
"strings" "strings"
"sync" "sync"
@@ -12,11 +13,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -24,6 +28,11 @@ import (
type domainMap map[domain.Domain][]netip.Prefix type domainMap map[domain.Domain][]netip.Prefix
type internalDNATer interface {
RemoveInternalDNATMapping(netip.Addr) error
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}
type wgInterface interface { type wgInterface interface {
Name() string Name() string
Address() wgaddr.Address Address() wgaddr.Address
@@ -40,26 +49,22 @@ type DnsInterceptor struct {
interceptedDomains domainMap interceptedDomains domainMap
wgInterface wgInterface wgInterface wgInterface
peerStore *peerstore.Store peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
} }
func New( func New(params common.HandlerParams) *DnsInterceptor {
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status,
dnsServer nbdns.Server,
wgInterface wgInterface,
peerStore *peerstore.Store,
) *DnsInterceptor {
return &DnsInterceptor{ return &DnsInterceptor{
route: rt, route: params.Route,
routeRefCounter: routeRefCounter, routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter, allowedIPsRefcounter: params.AllowedIPsRefCounter,
statusRecorder: statusRecorder, statusRecorder: params.StatusRecorder,
dnsServer: dnsServer, dnsServer: params.DnsServer,
wgInterface: wgInterface, wgInterface: params.WgInterface,
peerStore: params.PeerStore,
firewall: params.Firewall,
fakeIPManager: params.FakeIPManager,
interceptedDomains: make(domainMap), interceptedDomains: make(domainMap),
peerStore: peerStore,
} }
} }
@@ -78,9 +83,13 @@ func (d *DnsInterceptor) RemoveRoute() error {
var merr *multierror.Error var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains { for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes { for _, prefix := range prefixes {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil { // Routes should use fake IPs
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err)) routePrefix := d.transformRealToFakePrefix(prefix)
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", routePrefix, err))
} }
// AllowedIPs should use real IPs
if d.currentPeerKey != "" { if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
@@ -88,8 +97,10 @@ func (d *DnsInterceptor) RemoveRoute() error {
} }
} }
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
} }
d.cleanupDNATMappings()
for _, domain := range d.route.Domains { for _, domain := range d.route.Domains {
d.statusRecorder.DeleteResolvedDomainsStates(domain) d.statusRecorder.DeleteResolvedDomainsStates(domain)
} }
@@ -102,6 +113,68 @@ func (d *DnsInterceptor) RemoveRoute() error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// transformRealToFakePrefix returns fake IP prefix for routes (if DNAT enabled)
func (d *DnsInterceptor) transformRealToFakePrefix(realPrefix netip.Prefix) netip.Prefix {
if _, hasDNAT := d.internalDnatFw(); !hasDNAT {
return realPrefix
}
if fakeIP, ok := d.fakeIPManager.GetFakeIP(realPrefix.Addr()); ok {
return netip.PrefixFrom(fakeIP, realPrefix.Bits())
}
return realPrefix
}
// addAllowedIPForPrefix handles the AllowedIPs logic for a single prefix (uses real IPs)
func (d *DnsInterceptor) addAllowedIPForPrefix(realPrefix netip.Prefix, peerKey string, domain domain.Domain) error {
// AllowedIPs always use real IPs
ref, err := d.allowedIPsRefcounter.Increment(realPrefix, peerKey)
if err != nil {
return fmt.Errorf("add allowed IP %s: %v", realPrefix, err)
}
if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
realPrefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
return nil
}
// addRouteAndAllowedIP handles both route and AllowedIPs addition for a prefix
func (d *DnsInterceptor) addRouteAndAllowedIP(realPrefix netip.Prefix, domain domain.Domain) error {
// Routes use fake IPs (so traffic to fake IPs gets routed to interface)
routePrefix := d.transformRealToFakePrefix(realPrefix)
if _, err := d.routeRefCounter.Increment(routePrefix, struct{}{}); err != nil {
return fmt.Errorf("add route for IP %s: %v", routePrefix, err)
}
// Add to AllowedIPs if we have a current peer (uses real IPs)
if d.currentPeerKey == "" {
return nil
}
return d.addAllowedIPForPrefix(realPrefix, d.currentPeerKey, domain)
}
// removeAllowedIP handles AllowedIPs removal for a prefix (uses real IPs)
func (d *DnsInterceptor) removeAllowedIP(realPrefix netip.Prefix) error {
if d.currentPeerKey == "" {
return nil
}
// AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(realPrefix); err != nil {
return fmt.Errorf("remove allowed IP %s: %v", realPrefix, err)
}
return nil
}
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
@@ -109,14 +182,9 @@ func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
var merr *multierror.Error var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains { for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes { for _, prefix := range prefixes {
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { // AllowedIPs use real IPs
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) if err := d.addAllowedIPForPrefix(prefix, peerKey, domain); err != nil {
} else if ref.Count > 1 && ref.Out != peerKey { merr = multierror.Append(merr, err)
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
} }
} }
} }
@@ -132,6 +200,7 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
var merr *multierror.Error var merr *multierror.Error
for _, prefixes := range d.interceptedDomains { for _, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes { for _, prefix := range prefixes {
// AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
} }
@@ -287,6 +356,8 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err) log.Errorf("failed to update domain prefixes: %v", err)
} }
d.replaceIPsInDNSResponse(r, newPrefixes)
} }
} }
@@ -297,6 +368,22 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
return nil return nil
} }
// logPrefixChanges handles the logging for prefix changes
func (d *DnsInterceptor) logPrefixChanges(resolvedDomain, originalDomain domain.Domain, toAdd, toRemove []netip.Prefix) {
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
@@ -305,70 +392,163 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
var merr *multierror.Error var merr *multierror.Error
var dnatMappings map[netip.Addr]netip.Addr
// Handle DNAT mappings for new prefixes
if _, hasDNAT := d.internalDnatFw(); hasDNAT {
dnatMappings = make(map[netip.Addr]netip.Addr)
for _, prefix := range toAdd {
realIP := prefix.Addr()
if fakeIP, err := d.fakeIPManager.AllocateFakeIP(realIP); err == nil {
dnatMappings[fakeIP] = realIP
log.Tracef("allocated fake IP %s for real IP %s", fakeIP, realIP)
} else {
log.Errorf("Failed to allocate fake IP for %s: %v", realIP, err)
}
}
}
// Add new prefixes // Add new prefixes
for _, prefix := range toAdd { for _, prefix := range toAdd {
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil { if err := d.addRouteAndAllowedIP(prefix, resolvedDomain); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err)) merr = multierror.Append(merr, err)
continue
}
if d.currentPeerKey == "" {
continue
}
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
resolvedDomain.SafeString(),
ref.Out,
)
} }
} }
d.addDNATMappings(dnatMappings)
if !d.route.KeepRoute { if !d.route.KeepRoute {
// Remove old prefixes // Remove old prefixes
for _, prefix := range toRemove { for _, prefix := range toRemove {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil { // Routes use fake IPs
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err)) routePrefix := d.transformRealToFakePrefix(prefix)
if _, err := d.routeRefCounter.Decrement(routePrefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", routePrefix, err))
} }
if d.currentPeerKey != "" { // AllowedIPs use real IPs
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { if err := d.removeAllowedIP(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) merr = multierror.Append(merr, err)
}
} }
} }
d.removeDNATMappings(toRemove)
} }
// Update domain prefixes using resolved domain as key // Update domain prefixes using resolved domain as key - store real IPs
if len(toAdd) > 0 || len(toRemove) > 0 { if len(toAdd) > 0 || len(toRemove) > 0 {
if d.route.KeepRoute { if d.route.KeepRoute {
// replace stored prefixes with old + added
// nolint:gocritic // nolint:gocritic
newPrefixes = append(oldPrefixes, toAdd...) newPrefixes = append(oldPrefixes, toAdd...)
} }
d.interceptedDomains[resolvedDomain] = newPrefixes d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
// Store real IPs for status (user-facing), not fake IPs
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID()) d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
if len(toAdd) > 0 { d.logPrefixChanges(resolvedDomain, originalDomain, toAdd, toRemove)
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 && !d.route.KeepRoute {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// removeDNATMappings removes DNAT mappings from the firewall for real IP prefixes
func (d *DnsInterceptor) removeDNATMappings(realPrefixes []netip.Prefix) {
if len(realPrefixes) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for _, prefix := range realPrefixes {
realIP := prefix.Addr()
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
if err := dnatFirewall.RemoveInternalDNATMapping(fakeIP); err != nil {
log.Errorf("Failed to remove DNAT mapping for %s: %v", fakeIP, err)
} else {
log.Debugf("Removed DNAT mapping for: %s -> %s", fakeIP, realIP)
}
}
}
}
// internalDnatFw checks if the firewall supports internal DNAT
func (d *DnsInterceptor) internalDnatFw() (internalDNATer, bool) {
if d.firewall == nil || runtime.GOOS != "android" {
return nil, false
}
fw, ok := d.firewall.(internalDNATer)
return fw, ok
}
// addDNATMappings adds DNAT mappings to the firewall
func (d *DnsInterceptor) addDNATMappings(mappings map[netip.Addr]netip.Addr) {
if len(mappings) == 0 {
return
}
dnatFirewall, ok := d.internalDnatFw()
if !ok {
return
}
for fakeIP, realIP := range mappings {
if err := dnatFirewall.AddInternalDNATMapping(fakeIP, realIP); err != nil {
log.Errorf("Failed to add DNAT mapping %s -> %s: %v", fakeIP, realIP, err)
} else {
log.Debugf("Added DNAT mapping: %s -> %s", fakeIP, realIP)
}
}
}
// cleanupDNATMappings removes all DNAT mappings for this interceptor
func (d *DnsInterceptor) cleanupDNATMappings() {
if _, ok := d.internalDnatFw(); !ok {
return
}
for _, prefixes := range d.interceptedDomains {
d.removeDNATMappings(prefixes)
}
}
// replaceIPsInDNSResponse replaces real IPs with fake IPs in the DNS response
func (d *DnsInterceptor) replaceIPsInDNSResponse(reply *dns.Msg, realPrefixes []netip.Prefix) {
if _, ok := d.internalDnatFw(); !ok {
return
}
// Replace A and AAAA records with fake IPs
for _, answer := range reply.Answer {
switch rr := answer.(type) {
case *dns.A:
realIP, ok := netip.AddrFromSlice(rr.A)
if !ok {
continue
}
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.A = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
case *dns.AAAA:
realIP, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
continue
}
if fakeIP, exists := d.fakeIPManager.GetFakeIP(realIP); exists {
rr.AAAA = fakeIP.AsSlice()
log.Tracef("Replaced real IP %s with fake IP %s in DNS response", realIP, fakeIP)
}
}
}
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool) prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes { for _, prefix := range oldPrefixes {

View File

@@ -14,6 +14,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
@@ -52,24 +53,16 @@ type Route struct {
resolverAddr string resolverAddr string
} }
func NewRoute( func NewRoute(params common.HandlerParams, resolverAddr string) *Route {
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.WGIface,
resolverAddr string,
) *Route {
return &Route{ return &Route{
route: rt, route: params.Route,
routeRefCounter: routeRefCounter, routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter, allowedIPsRefcounter: params.AllowedIPsRefCounter,
interval: interval, interval: params.DnsRouterInterval,
dynamicDomains: domainMap{}, statusRecorder: params.StatusRecorder,
statusRecorder: statusRecorder, wgInterface: params.WgInterface,
wgInterface: wgInterface,
resolverAddr: resolverAddr, resolverAddr: resolverAddr,
dynamicDomains: domainMap{},
} }
} }

View File

@@ -0,0 +1,93 @@
package fakeip
import (
"fmt"
"net/netip"
"sync"
)
// Manager manages allocation of fake IPs from the 240.0.0.0/8 block
type Manager struct {
mu sync.Mutex
nextIP netip.Addr // Next IP to allocate
allocated map[netip.Addr]netip.Addr // real IP -> fake IP
fakeToReal map[netip.Addr]netip.Addr // fake IP -> real IP
baseIP netip.Addr // First usable IP: 240.0.0.1
maxIP netip.Addr // Last usable IP: 240.255.255.254
}
// NewManager creates a new fake IP manager using 240.0.0.0/8 block
func NewManager() *Manager {
baseIP := netip.AddrFrom4([4]byte{240, 0, 0, 1})
maxIP := netip.AddrFrom4([4]byte{240, 255, 255, 254})
return &Manager{
nextIP: baseIP,
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: baseIP,
maxIP: maxIP,
}
}
// AllocateFakeIP allocates a fake IP for the given real IP
// Returns the fake IP, or existing fake IP if already allocated
func (m *Manager) AllocateFakeIP(realIP netip.Addr) (netip.Addr, error) {
if !realIP.Is4() {
return netip.Addr{}, fmt.Errorf("only IPv4 addresses supported")
}
m.mu.Lock()
defer m.mu.Unlock()
if fakeIP, exists := m.allocated[realIP]; exists {
return fakeIP, nil
}
startIP := m.nextIP
for {
currentIP := m.nextIP
// Advance to next IP, wrapping at boundary
if m.nextIP.Compare(m.maxIP) >= 0 {
m.nextIP = m.baseIP
} else {
m.nextIP = m.nextIP.Next()
}
// Check if current IP is available
if _, inUse := m.fakeToReal[currentIP]; !inUse {
m.allocated[realIP] = currentIP
m.fakeToReal[currentIP] = realIP
return currentIP, nil
}
// Prevent infinite loop if all IPs exhausted
if m.nextIP.Compare(startIP) == 0 {
return netip.Addr{}, fmt.Errorf("no more fake IPs available in 240.0.0.0/8 block")
}
}
}
// GetFakeIP returns the fake IP for a real IP if it exists
func (m *Manager) GetFakeIP(realIP netip.Addr) (netip.Addr, bool) {
m.mu.Lock()
defer m.mu.Unlock()
fakeIP, exists := m.allocated[realIP]
return fakeIP, exists
}
// GetRealIP returns the real IP for a fake IP if it exists, otherwise false
func (m *Manager) GetRealIP(fakeIP netip.Addr) (netip.Addr, bool) {
m.mu.Lock()
defer m.mu.Unlock()
realIP, exists := m.fakeToReal[fakeIP]
return realIP, exists
}
// GetFakeIPBlock returns the fake IP block used by this manager
func (m *Manager) GetFakeIPBlock() netip.Prefix {
return netip.MustParsePrefix("240.0.0.0/8")
}

View File

@@ -0,0 +1,240 @@
package fakeip
import (
"net/netip"
"sync"
"testing"
)
func TestNewManager(t *testing.T) {
manager := NewManager()
if manager.baseIP.String() != "240.0.0.1" {
t.Errorf("Expected base IP 240.0.0.1, got %s", manager.baseIP.String())
}
if manager.maxIP.String() != "240.255.255.254" {
t.Errorf("Expected max IP 240.255.255.254, got %s", manager.maxIP.String())
}
if manager.nextIP.Compare(manager.baseIP) != 0 {
t.Errorf("Expected nextIP to start at baseIP")
}
}
func TestAllocateFakeIP(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("8.8.8.8")
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err)
}
if !fakeIP.Is4() {
t.Error("Fake IP should be IPv4")
}
// Check it's in the correct range
if fakeIP.As4()[0] != 240 {
t.Errorf("Fake IP should be in 240.0.0.0/8 range, got %s", fakeIP.String())
}
// Should return same fake IP for same real IP
fakeIP2, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to get existing fake IP: %v", err)
}
if fakeIP.Compare(fakeIP2) != 0 {
t.Errorf("Expected same fake IP for same real IP, got %s and %s", fakeIP.String(), fakeIP2.String())
}
}
func TestAllocateFakeIPIPv6Rejection(t *testing.T) {
manager := NewManager()
realIPv6 := netip.MustParseAddr("2001:db8::1")
_, err := manager.AllocateFakeIP(realIPv6)
if err == nil {
t.Error("Expected error for IPv6 address")
}
}
func TestGetFakeIP(t *testing.T) {
manager := NewManager()
realIP := netip.MustParseAddr("1.1.1.1")
// Should not exist initially
_, exists := manager.GetFakeIP(realIP)
if exists {
t.Error("Fake IP should not exist before allocation")
}
// Allocate and check
expectedFakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate: %v", err)
}
fakeIP, exists := manager.GetFakeIP(realIP)
if !exists {
t.Error("Fake IP should exist after allocation")
}
if fakeIP.Compare(expectedFakeIP) != 0 {
t.Errorf("Expected %s, got %s", expectedFakeIP.String(), fakeIP.String())
}
}
func TestMultipleAllocations(t *testing.T) {
manager := NewManager()
allocations := make(map[netip.Addr]netip.Addr)
// Allocate multiple IPs
for i := 1; i <= 100; i++ {
realIP := netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP for %s: %v", realIP.String(), err)
}
// Check for duplicates
for _, existingFake := range allocations {
if fakeIP.Compare(existingFake) == 0 {
t.Errorf("Duplicate fake IP allocated: %s", fakeIP.String())
}
}
allocations[realIP] = fakeIP
}
// Verify all allocations can be retrieved
for realIP, expectedFake := range allocations {
actualFake, exists := manager.GetFakeIP(realIP)
if !exists {
t.Errorf("Missing allocation for %s", realIP.String())
}
if actualFake.Compare(expectedFake) != 0 {
t.Errorf("Mismatch for %s: expected %s, got %s", realIP.String(), expectedFake.String(), actualFake.String())
}
}
}
func TestGetFakeIPBlock(t *testing.T) {
manager := NewManager()
block := manager.GetFakeIPBlock()
expected := "240.0.0.0/8"
if block.String() != expected {
t.Errorf("Expected %s, got %s", expected, block.String())
}
}
func TestConcurrentAccess(t *testing.T) {
manager := NewManager()
const numGoroutines = 50
const allocationsPerGoroutine = 10
var wg sync.WaitGroup
results := make(chan netip.Addr, numGoroutines*allocationsPerGoroutine)
// Concurrent allocations
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(goroutineID int) {
defer wg.Done()
for j := 0; j < allocationsPerGoroutine; j++ {
realIP := netip.AddrFrom4([4]byte{192, 168, byte(goroutineID), byte(j)})
fakeIP, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Errorf("Failed to allocate in goroutine %d: %v", goroutineID, err)
return
}
results <- fakeIP
}
}(i)
}
wg.Wait()
close(results)
// Check for duplicates
seen := make(map[netip.Addr]bool)
count := 0
for fakeIP := range results {
if seen[fakeIP] {
t.Errorf("Duplicate fake IP in concurrent test: %s", fakeIP.String())
}
seen[fakeIP] = true
count++
}
if count != numGoroutines*allocationsPerGoroutine {
t.Errorf("Expected %d allocations, got %d", numGoroutines*allocationsPerGoroutine, count)
}
}
func TestIPExhaustion(t *testing.T) {
// Create a manager with limited range for testing
manager := &Manager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 3}), // Only 3 IPs available
}
// Allocate all available IPs
realIPs := []netip.Addr{
netip.MustParseAddr("1.0.0.1"),
netip.MustParseAddr("1.0.0.2"),
netip.MustParseAddr("1.0.0.3"),
}
for _, realIP := range realIPs {
_, err := manager.AllocateFakeIP(realIP)
if err != nil {
t.Fatalf("Failed to allocate fake IP: %v", err)
}
}
// Try to allocate one more - should fail
_, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.4"))
if err == nil {
t.Error("Expected exhaustion error")
}
}
func TestWrapAround(t *testing.T) {
// Create manager starting near the end of range
manager := &Manager{
nextIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
allocated: make(map[netip.Addr]netip.Addr),
fakeToReal: make(map[netip.Addr]netip.Addr),
baseIP: netip.AddrFrom4([4]byte{240, 0, 0, 1}),
maxIP: netip.AddrFrom4([4]byte{240, 0, 0, 254}),
}
// Allocate the last IP
fakeIP1, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.1"))
if err != nil {
t.Fatalf("Failed to allocate first IP: %v", err)
}
if fakeIP1.String() != "240.0.0.254" {
t.Errorf("Expected 240.0.0.254, got %s", fakeIP1.String())
}
// Next allocation should wrap around to the beginning
fakeIP2, err := manager.AllocateFakeIP(netip.MustParseAddr("1.0.0.2"))
if err != nil {
t.Fatalf("Failed to allocate second IP: %v", err)
}
if fakeIP2.String() != "240.0.0.1" {
t.Errorf("Expected 240.0.0.1 after wrap, got %s", fakeIP2.String())
}
}

View File

@@ -8,9 +8,11 @@ import (
"net/netip" "net/netip"
"net/url" "net/url"
"runtime" "runtime"
"slices"
"sync" "sync"
"time" "time"
"github.com/google/uuid"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -24,6 +26,8 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/client" "github.com/netbirdio/netbird/client/internal/routemanager/client"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@@ -40,7 +44,7 @@ import (
// Manager is a route manager interface // Manager is a route manager interface
type Manager interface { type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) Init() error
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
TriggerSelection(route.HAMap) TriggerSelection(route.HAMap)
@@ -49,7 +53,7 @@ type Manager interface {
GetClientRoutesWithNetID() map[route.NetID][]*route.Route GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener) SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error SetFirewall(firewall.Manager) error
Stop(stateManager *statemanager.Manager) Stop(stateManager *statemanager.Manager)
} }
@@ -63,6 +67,7 @@ type ManagerConfig struct {
InitialRoutes []*route.Route InitialRoutes []*route.Route
StateManager *statemanager.Manager StateManager *statemanager.Manager
DNSServer dns.Server DNSServer dns.Server
DNSFeatureFlag bool
PeerStore *peerstore.Store PeerStore *peerstore.Store
DisableClientRoutes bool DisableClientRoutes bool
DisableServerRoutes bool DisableServerRoutes bool
@@ -89,11 +94,13 @@ type DefaultManager struct {
// clientRoutes is the most recent list of clientRoutes received from the Management Service // clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap clientRoutes route.HAMap
dnsServer dns.Server dnsServer dns.Server
firewall firewall.Manager
peerStore *peerstore.Store peerStore *peerstore.Store
useNewDNSRoute bool useNewDNSRoute bool
disableClientRoutes bool disableClientRoutes bool
disableServerRoutes bool disableServerRoutes bool
activeRoutes map[route.HAUniqueID]client.RouteHandler activeRoutes map[route.HAUniqueID]client.RouteHandler
fakeIPManager *fakeip.Manager
} }
func NewManager(config ManagerConfig) *DefaultManager { func NewManager(config ManagerConfig) *DefaultManager {
@@ -129,11 +136,31 @@ func NewManager(config ManagerConfig) *DefaultManager {
} }
if runtime.GOOS == "android" { if runtime.GOOS == "android" {
cr := dm.initialClientRoutes(config.InitialRoutes) dm.setupAndroidRoutes(config)
dm.notifier.SetInitialClientRoutes(cr)
} }
return dm return dm
} }
func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
cr := m.initialClientRoutes(config.InitialRoutes)
routesForComparison := slices.Clone(cr)
if config.DNSFeatureFlag {
m.fakeIPManager = fakeip.NewManager()
id := uuid.NewString()
fakeIPRoute := &route.Route{
ID: route.ID(id),
Network: m.fakeIPManager.GetFakeIPBlock(),
NetID: route.NetID(id),
Peer: m.pubKey,
NetworkType: route.IPv4Network,
}
cr = append(cr, fakeIPRoute)
}
m.notifier.SetInitialClientRoutes(cr, routesForComparison)
}
func (m *DefaultManager) setupRefCounters(useNoop bool) { func (m *DefaultManager) setupRefCounters(useNoop bool) {
m.routeRefCounter = refcounter.New( m.routeRefCounter = refcounter.New(
@@ -174,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
} }
// Init sets up the routing // Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (m *DefaultManager) Init() error {
m.routeSelector = m.initSelector() m.routeSelector = m.initSelector()
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
return nil, nil, nil return nil
} }
if err := m.sysOps.CleanupRouting(nil); err != nil { if err := m.sysOps.CleanupRouting(nil); err != nil {
@@ -192,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
ips := resolveURLsToIPs(initialAddresses) ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
if err != nil { return fmt.Errorf("setup routing: %w", err)
return nil, nil, fmt.Errorf("setup routing: %w", err)
} }
log.Info("Routing setup complete") log.Info("Routing setup complete")
return beforePeerHook, afterPeerHook, nil return nil
} }
func (m *DefaultManager) initSelector() *routeselector.RouteSelector { func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
@@ -222,16 +248,16 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
return routeselector.NewRouteSelector() return routeselector.NewRouteSelector()
} }
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { // SetFirewall sets the firewall manager for the DefaultManager
if m.disableServerRoutes { // Not thread-safe, should be called before starting the manager
func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
m.firewall = firewall
if m.disableServerRoutes || firewall == nil {
log.Info("server routes are disabled") log.Info("server routes are disabled")
return nil return nil
} }
if firewall == nil {
return errors.New("firewall manager is not set")
}
var err error var err error
m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) m.serverRouter, err = server.NewRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil { if err != nil {
@@ -299,17 +325,20 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
} }
for id, route := range toAdd { for id, route := range toAdd {
handler := client.HandlerFromRoute( params := common.HandlerParams{
route, Route: route,
m.routeRefCounter, RouteRefCounter: m.routeRefCounter,
m.allowedIPsRefCounter, AllowedIPsRefCounter: m.allowedIPsRefCounter,
m.dnsRouteInterval, DnsRouterInterval: m.dnsRouteInterval,
m.statusRecorder, StatusRecorder: m.statusRecorder,
m.wgInterface, WgInterface: m.wgInterface,
m.dnsServer, DnsServer: m.dnsServer,
m.peerStore, PeerStore: m.peerStore,
m.useNewDNSRoute, UseNewDNSRoute: m.useNewDNSRoute,
) Firewall: m.firewall,
FakeIPManager: m.fakeIPManager,
}
handler := client.HandlerFromRoute(params)
if err := handler.AddRoute(m.ctx); err != nil { if err := handler.AddRoute(m.ctx); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err)) merr = multierror.Append(merr, fmt.Errorf("add route %s: %w", handler.String(), err))
continue continue
@@ -517,6 +546,7 @@ func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*ro
for _, routes := range crMap { for _, routes := range crMap {
rs = append(rs, routes...) rs = append(rs, routes...)
} }
return rs return rs
} }

View File

@@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
StatusRecorder: statusRecorder, StatusRecorder: statusRecorder,
}) })
_, _, err = routeManager.Init() err = routeManager.Init()
require.NoError(t, err, "should init route manager") require.NoError(t, err, "should init route manager")
defer routeManager.Stop(nil) defer routeManager.Stop(nil)

View File

@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
) )
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
@@ -23,8 +22,8 @@ type MockManager struct {
StopFunc func(manager *statemanager.Manager) StopFunc func(manager *statemanager.Manager)
} }
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { func (m *MockManager) Init() error {
return nil, nil, nil return nil
} }
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface // InitialRouteRange mock implementation of InitialRouteRange from Manager interface
@@ -87,7 +86,7 @@ func (m *MockManager) SetRouteChangeListener(listener listener.NetworkChangeList
} }
func (m *MockManager) EnableServerRouter(firewall firewall.Manager) error { func (m *MockManager) SetFirewall(firewall.Manager) error {
panic("implement me") panic("implement me")
} }

View File

@@ -1,124 +0,0 @@
package notifier
import (
"net/netip"
"runtime"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
initialRouteRanges []string
routeRanges []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0)
for _, r := range clientRoutes {
if r.IsDynamic() {
continue
}
nets = append(nets, r.Network.String())
}
sort.Strings(nets)
n.initialRouteRanges = nets
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" {
return
}
var newNets []string
for _, routes := range idMap {
for _, r := range routes {
if r.IsDynamic() {
continue
}
newNets = append(newNets, r.Network.String())
}
}
sort.Strings(newNets)
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
n.routeRanges = newNets
n.notify()
}
// OnNewPrefixes is called from iOS only
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
if !n.hasDiff(n.routeRanges, newNets) {
return
}
n.routeRanges = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(addIPv6RangeIfNeeded(n.routeRanges), ","))
}(n.listener)
}
func (n *Notifier) hasDiff(a []string, b []string) bool {
if len(a) != len(b) {
return true
}
for i, v := range a {
if v != b[i] {
return true
}
}
return false
}
func (n *Notifier) GetInitialRouteRanges() []string {
return addIPv6RangeIfNeeded(n.initialRouteRanges)
}
// addIPv6RangeIfNeeded returns the input ranges with the default IPv6 range when there is an IPv4 default route.
func addIPv6RangeIfNeeded(inputRanges []string) []string {
ranges := inputRanges
for _, r := range inputRanges {
// we are intentionally adding the ipv6 default range in case of ipv4 default range
// to ensure that all traffic is managed by the tunnel interface on android
if r == "0.0.0.0/0" {
ranges = append(ranges, "::/0")
break
}
}
return ranges
}

View File

@@ -0,0 +1,127 @@
//go:build android
package notifier
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes(initialRoutes []*route.Route, routesForComparison []*route.Route) {
// initialRoutes contains fake IP block for interface configuration
filteredInitial := make([]*route.Route, 0)
for _, r := range initialRoutes {
if r.IsDynamic() {
continue
}
filteredInitial = append(filteredInitial, r)
}
n.initialRoutes = filteredInitial
// routesForComparison excludes fake IP block for comparison with new routes
filteredComparison := make([]*route.Route, 0)
for _, r := range routesForComparison {
if r.IsDynamic() {
continue
}
filteredComparison = append(filteredComparison, r)
}
n.currentRoutes = filteredComparison
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
var newRoutes []*route.Route
for _, routes := range idMap {
for _, r := range routes {
if r.IsDynamic() {
continue
}
newRoutes = append(newRoutes, r)
}
}
if !n.hasRouteDiff(n.currentRoutes, newRoutes) {
return
}
n.currentRoutes = newRoutes
n.notify()
}
func (n *Notifier) OnNewPrefixes([]netip.Prefix) {
// Not used on Android
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
routeStrings := n.routesToStrings(n.currentRoutes)
sort.Strings(routeStrings)
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(routeStrings, n.currentRoutes), ","))
}(n.listener)
}
func (n *Notifier) routesToStrings(routes []*route.Route) []string {
nets := make([]string, 0, len(routes))
for _, r := range routes {
nets = append(nets, r.NetString())
}
return nets
}
func (n *Notifier) hasRouteDiff(a []*route.Route, b []*route.Route) bool {
slices.SortFunc(a, func(x, y *route.Route) int {
return strings.Compare(x.NetString(), y.NetString())
})
slices.SortFunc(b, func(x, y *route.Route) int {
return strings.Compare(x.NetString(), y.NetString())
})
return !slices.EqualFunc(a, b, func(x, y *route.Route) bool {
return x.NetString() == y.NetString()
})
}
func (n *Notifier) GetInitialRouteRanges() []string {
initialStrings := n.routesToStrings(n.initialRoutes)
sort.Strings(initialStrings)
return n.addIPv6RangeIfNeeded(initialStrings, n.initialRoutes)
}
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string, routes []*route.Route) []string {
for _, r := range routes {
if r.Network.Addr().Is4() && r.Network.Bits() == 0 {
return append(slices.Clone(inputRanges), "::/0")
}
}
return inputRanges
}

View File

@@ -0,0 +1,80 @@
//go:build ios
package notifier
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct {
currentPrefixes []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.listener = listener
}
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// iOS doesn't care about initial routes
}
func (n *Notifier) OnNewRoutes(route.HAMap) {
// Not used on iOS
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
if slices.Equal(n.currentPrefixes, newNets) {
return
}
n.currentPrefixes = newNets
n.notify()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.addIPv6RangeIfNeeded(n.currentPrefixes), ","))
}(n.listener)
}
func (n *Notifier) GetInitialRouteRanges() []string {
return nil
}
func (n *Notifier) addIPv6RangeIfNeeded(inputRanges []string) []string {
for _, r := range inputRanges {
if r == "0.0.0.0/0" {
return append(slices.Clone(inputRanges), "::/0")
}
}
return inputRanges
}

View File

@@ -0,0 +1,36 @@
//go:build !android && !ios
package notifier
import (
"net/netip"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/route"
)
type Notifier struct{}
func NewNotifier() *Notifier {
return &Notifier{}
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
// Not used on non-mobile platforms
}
func (n *Notifier) SetInitialClientRoutes([]*route.Route, []*route.Route) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
// Not used on non-mobile platforms
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
// Not used on non-mobile platforms
}
func (n *Notifier) GetInitialRouteRanges() []string {
return []string{}
}

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -16,11 +17,11 @@ type Route struct {
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
} }
func NewRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *Route { func NewRoute(params common.HandlerParams) *Route {
return &Route{ return &Route{
route: rt, route: params.Route,
routeRefCounter: routeRefCounter, routeRefCounter: params.RouteRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter, allowedIPsRefcounter: params.AllowedIPsRefCounter,
} }
} }

View File

@@ -5,6 +5,8 @@ import (
"net" "net"
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic"
"time"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -52,6 +54,13 @@ type SysOps struct {
mu sync.Mutex mu sync.Mutex
// notifier is used to notify the system of route changes (also used on mobile) // notifier is used to notify the system of route changes (also used on mobile)
notifier *notifier.Notifier notifier *notifier.Notifier
// seq is an atomic counter for generating unique sequence numbers for route messages
//nolint:unused // only used on BSD systems
seq atomic.Uint32
localSubnetsCache []*net.IPNet
localSubnetsCacheMu sync.RWMutex
localSubnetsCacheTime time.Time
} }
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
@@ -61,6 +70,11 @@ func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
} }
} }
//nolint:unused // only used on BSD systems
func (r *SysOps) getSeq() int {
return int(r.seq.Add(1))
}
func (r *SysOps) validateRoute(prefix netip.Prefix) error { func (r *SysOps) validateRoute(prefix netip.Prefix) error {
addr := prefix.Addr() addr := prefix.Addr()

View File

@@ -10,11 +10,10 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strconv" "strconv"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/libp2p/go-netroute" "github.com/libp2p/go-netroute"
@@ -24,6 +25,8 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const localSubnetsCacheTTL = 15 * time.Minute
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
@@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate") var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
r.refCounter = refCounter r.refCounter = refCounter
return r.setupHooks(initAddresses, stateManager) if err := r.setupHooks(initAddresses, stateManager); err != nil {
return fmt.Errorf("setup hooks: %w", err)
}
return nil
} }
// updateState updates state on every change so it will be persisted regularly // updateState updates state on every change so it will be persisted regularly
@@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, fmt.Errorf("get next hop: %w", err) return Nexthop{}, fmt.Errorf("get next hop: %w", err)
} }
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP) log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
exitNextHop := Nexthop{ exitNextHop := nexthop
IP: nexthop.IP,
Intf: nexthop.Intf,
}
vpnAddr := vpnIntf.Address().IP vpnAddr := vpnIntf.Address().IP
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop exitNextHop = initialNextHop
} }
@@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
} }
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
r.localSubnetsCacheMu.RLock()
cacheAge := time.Since(r.localSubnetsCacheTime)
subnets := r.localSubnetsCache
r.localSubnetsCacheMu.RUnlock()
if cacheAge > localSubnetsCacheTTL || subnets == nil {
r.localSubnetsCacheMu.Lock()
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
r.refreshLocalSubnetsCache()
}
subnets = r.localSubnetsCache
r.localSubnetsCacheMu.Unlock()
}
for _, subnet := range subnets {
if subnet.Contains(prefix.Addr().AsSlice()) {
return true, subnet
}
}
return false, nil
}
func (r *SysOps) refreshLocalSubnetsCache() {
localInterfaces, err := net.Interfaces() localInterfaces, err := net.Interfaces()
if err != nil { if err != nil {
log.Errorf("Failed to get local interfaces: %v", err) log.Errorf("Failed to get local interfaces: %v", err)
return false, nil return
} }
var newSubnets []*net.IPNet
for _, intf := range localInterfaces { for _, intf := range localInterfaces {
addrs, err := intf.Addrs() addrs, err := intf.Addrs()
if err != nil { if err != nil {
@@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet)
log.Errorf("Failed to convert address to IPNet: %v", addr) log.Errorf("Failed to convert address to IPNet: %v", addr)
continue continue
} }
newSubnets = append(newSubnets, ipnet)
if ipnet.Contains(prefix.Addr().AsSlice()) {
return true, ipnet
}
} }
} }
return false, nil r.localSubnetsCache = newSubnets
r.localSubnetsCacheTime = time.Now()
} }
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
@@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop) return r.removeFromRouteTable(prefix, nextHop)
} }
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip) prefix, err := util.GetPrefixFromIP(ip)
if err != nil { if err != nil {
@@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil return nil
} }
var merr *multierror.Error
for _, ip := range initAddresses { for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil { if err := beforeHook("init", ip); err != nil {
log.Errorf("Failed to add route reference: %v", err) merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
} }
} }
@@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return ctx.Err() return ctx.Err()
} }
var result *multierror.Error var merr *multierror.Error
for _, ip := range resolvedIPs { for _, ip := range resolvedIPs {
result = multierror.Append(result, beforeHook(connID, ip.IP)) merr = multierror.Append(merr, beforeHook(connID, ip.IP))
} }
return nberrors.FormatErrorOrNil(result) return nberrors.FormatErrorOrNil(merr)
}) })
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
@@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return afterHook(connID) return afterHook(connID)
}) })
return beforeHook, afterHook, nil nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
r.updateState(stateManager)
return nil
})
return nberrors.FormatErrorOrNil(merr)
} }
func GetNextHop(ip netip.Addr) (Nexthop, error) { func GetNextHop(ip netip.Addr) (Nexthop, error) {

View File

@@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))
@@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
}) })
r := NewSysOps(wgInterface, nil) r := NewSysOps(wgInterface, nil)
_, _, err := r.SetupRouting(nil, nil) err := r.SetupRouting(nil, nil)
require.NoError(t, err, "setupRouting should not return err") require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() { t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil)) assert.NoError(t, r.CleanupRouting(nil))

View File

@@ -10,14 +10,13 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{}) r.prefixes = make(map[netip.Prefix]struct{})
return nil, nil, nil return nil
} }
func (r *SysOps) CleanupRouting(*statemanager.Manager) error { func (r *SysOps) CleanupRouting(*statemanager.Manager) error {

View File

@@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured, // This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity. // enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() { if !nbnet.AdvancedRouting() {
log.Infof("Using legacy routing setup") log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
@@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
rules := getSetupRules() rules := getSetupRules()
for _, rule := range rules { for _, rule := range rules {
if err := addRule(rule); err != nil { if err := addRule(rule); err != nil {
return nil, nil, fmt.Errorf("%s: %w", rule.description, err) return fmt.Errorf("%s: %w", rule.description, err)
} }
} }
@@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
} }
originalSysctl = originalValues originalSysctl = originalValues
return nil, nil, nil return nil
} }
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.

View File

@@ -252,7 +252,7 @@ func TestSysOps_validateRoute_InvalidPrefix(t *testing.T) {
IP: wgNetwork.Addr(), IP: wgNetwork.Addr(),
Network: wgNetwork, Network: wgNetwork,
}, },
name: "wt0", name: "nb0",
} }
sysOps := &SysOps{ sysOps := &SysOps{

View File

@@ -18,10 +18,9 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
@@ -108,7 +107,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action, Type: action,
Flags: unix.RTF_UP, Flags: unix.RTF_UP,
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
Seq: 1, Seq: r.getSeq(),
} }
const numAddrs = unix.RTAX_NETMASK + 1 const numAddrs = unix.RTAX_NETMASK + 1

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const InfiniteLifetime = 0xffffffff const InfiniteLifetime = 0xffffffff
@@ -137,7 +136,7 @@ const (
RouteDeleted RouteDeleted
) )
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }

View File

@@ -1330,6 +1330,13 @@ func (x *PeerState) GetRelayAddress() string {
return "" return ""
} }
func (x *PeerState) GetConnectionType() string {
if x.Relayed {
return "Relayed"
}
return "P2P"
}
// LocalPeerState contains the latest state of the local peer // LocalPeerState contains the latest state of the local peer
type LocalPeerState struct { type LocalPeerState struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
@@ -2290,6 +2297,7 @@ type DebugBundleRequest struct {
Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"`
SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"` SystemInfo bool `protobuf:"varint,3,opt,name=systemInfo,proto3" json:"systemInfo,omitempty"`
UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"` UploadURL string `protobuf:"bytes,4,opt,name=uploadURL,proto3" json:"uploadURL,omitempty"`
LogFileCount uint32 `protobuf:"varint,5,opt,name=logFileCount,proto3" json:"logFileCount,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -2352,6 +2360,13 @@ func (x *DebugBundleRequest) GetUploadURL() string {
return "" return ""
} }
func (x *DebugBundleRequest) GetLogFileCount() uint32 {
if x != nil {
return x.LogFileCount
}
return 0
}
type DebugBundleResponse struct { type DebugBundleResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"` Path string `protobuf:"bytes,1,opt,name=path,proto3" json:"path,omitempty"`
@@ -3746,14 +3761,15 @@ const file_daemon_proto_rawDesc = "" +
"\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" + "\x12translatedHostname\x18\x04 \x01(\tR\x12translatedHostname\x128\n" +
"\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" + "\x0etranslatedPort\x18\x05 \x01(\v2\x10.daemon.PortInfoR\x0etranslatedPort\"G\n" +
"\x17ForwardingRulesResponse\x12,\n" + "\x17ForwardingRulesResponse\x12,\n" +
"\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\x88\x01\n" + "\x05rules\x18\x01 \x03(\v2\x16.daemon.ForwardingRuleR\x05rules\"\xac\x01\n" +
"\x12DebugBundleRequest\x12\x1c\n" + "\x12DebugBundleRequest\x12\x1c\n" +
"\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" + "\tanonymize\x18\x01 \x01(\bR\tanonymize\x12\x16\n" +
"\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" + "\x06status\x18\x02 \x01(\tR\x06status\x12\x1e\n" +
"\n" + "\n" +
"systemInfo\x18\x03 \x01(\bR\n" + "systemInfo\x18\x03 \x01(\bR\n" +
"systemInfo\x12\x1c\n" + "systemInfo\x12\x1c\n" +
"\tuploadURL\x18\x04 \x01(\tR\tuploadURL\"}\n" + "\tuploadURL\x18\x04 \x01(\tR\tuploadURL\x12\"\n" +
"\flogFileCount\x18\x05 \x01(\rR\flogFileCount\"}\n" +
"\x13DebugBundleResponse\x12\x12\n" + "\x13DebugBundleResponse\x12\x12\n" +
"\x04path\x18\x01 \x01(\tR\x04path\x12 \n" + "\x04path\x18\x01 \x01(\tR\x04path\x12 \n" +
"\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" + "\vuploadedKey\x18\x02 \x01(\tR\vuploadedKey\x120\n" +

View File

@@ -356,6 +356,7 @@ message DebugBundleRequest {
string status = 2; string status = 2;
bool systemInfo = 3; bool systemInfo = 3;
string uploadURL = 4; string uploadURL = 4;
uint32 logFileCount = 5;
} }
message DebugBundleResponse { message DebugBundleResponse {

View File

@@ -1,8 +1,4 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.5.1
// - protoc v5.29.3
// source: daemon.proto
package proto package proto
@@ -15,31 +11,8 @@ import (
// This is a compile-time assertion to ensure that this generated file // This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against. // is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later. // Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion9 const _ = grpc.SupportPackageIsVersion7
const (
DaemonService_Login_FullMethodName = "/daemon.DaemonService/Login"
DaemonService_WaitSSOLogin_FullMethodName = "/daemon.DaemonService/WaitSSOLogin"
DaemonService_Up_FullMethodName = "/daemon.DaemonService/Up"
DaemonService_Status_FullMethodName = "/daemon.DaemonService/Status"
DaemonService_Down_FullMethodName = "/daemon.DaemonService/Down"
DaemonService_GetConfig_FullMethodName = "/daemon.DaemonService/GetConfig"
DaemonService_ListNetworks_FullMethodName = "/daemon.DaemonService/ListNetworks"
DaemonService_SelectNetworks_FullMethodName = "/daemon.DaemonService/SelectNetworks"
DaemonService_DeselectNetworks_FullMethodName = "/daemon.DaemonService/DeselectNetworks"
DaemonService_ForwardingRules_FullMethodName = "/daemon.DaemonService/ForwardingRules"
DaemonService_DebugBundle_FullMethodName = "/daemon.DaemonService/DebugBundle"
DaemonService_GetLogLevel_FullMethodName = "/daemon.DaemonService/GetLogLevel"
DaemonService_SetLogLevel_FullMethodName = "/daemon.DaemonService/SetLogLevel"
DaemonService_ListStates_FullMethodName = "/daemon.DaemonService/ListStates"
DaemonService_CleanState_FullMethodName = "/daemon.DaemonService/CleanState"
DaemonService_DeleteState_FullMethodName = "/daemon.DaemonService/DeleteState"
DaemonService_SetNetworkMapPersistence_FullMethodName = "/daemon.DaemonService/SetNetworkMapPersistence"
DaemonService_TracePacket_FullMethodName = "/daemon.DaemonService/TracePacket"
DaemonService_SubscribeEvents_FullMethodName = "/daemon.DaemonService/SubscribeEvents"
DaemonService_GetEvents_FullMethodName = "/daemon.DaemonService/GetEvents"
)
// DaemonServiceClient is the client API for DaemonService service. // DaemonServiceClient is the client API for DaemonService service.
// //
@@ -80,7 +53,7 @@ type DaemonServiceClient interface {
// SetNetworkMapPersistence enables or disables network map persistence // SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error)
TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error)
SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error)
GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error)
} }
@@ -93,9 +66,8 @@ func NewDaemonServiceClient(cc grpc.ClientConnInterface) DaemonServiceClient {
} }
func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) { func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(LoginResponse) out := new(LoginResponse)
err := c.cc.Invoke(ctx, DaemonService_Login_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/Login", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -103,9 +75,8 @@ func (c *daemonServiceClient) Login(ctx context.Context, in *LoginRequest, opts
} }
func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLoginRequest, opts ...grpc.CallOption) (*WaitSSOLoginResponse, error) { func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLoginRequest, opts ...grpc.CallOption) (*WaitSSOLoginResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(WaitSSOLoginResponse) out := new(WaitSSOLoginResponse)
err := c.cc.Invoke(ctx, DaemonService_WaitSSOLogin_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitSSOLogin", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -113,9 +84,8 @@ func (c *daemonServiceClient) WaitSSOLogin(ctx context.Context, in *WaitSSOLogin
} }
func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) { func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grpc.CallOption) (*UpResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(UpResponse) out := new(UpResponse)
err := c.cc.Invoke(ctx, DaemonService_Up_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/Up", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -123,9 +93,8 @@ func (c *daemonServiceClient) Up(ctx context.Context, in *UpRequest, opts ...grp
} }
func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) { func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(StatusResponse) out := new(StatusResponse)
err := c.cc.Invoke(ctx, DaemonService_Status_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/Status", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -133,9 +102,8 @@ func (c *daemonServiceClient) Status(ctx context.Context, in *StatusRequest, opt
} }
func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) { func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(DownResponse) out := new(DownResponse)
err := c.cc.Invoke(ctx, DaemonService_Down_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/Down", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -143,9 +111,8 @@ func (c *daemonServiceClient) Down(ctx context.Context, in *DownRequest, opts ..
} }
func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) { func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetConfigResponse) out := new(GetConfigResponse)
err := c.cc.Invoke(ctx, DaemonService_GetConfig_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetConfig", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -153,9 +120,8 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques
} }
func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListNetworksResponse) out := new(ListNetworksResponse)
err := c.cc.Invoke(ctx, DaemonService_ListNetworks_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -163,9 +129,8 @@ func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworks
} }
func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SelectNetworksResponse) out := new(SelectNetworksResponse)
err := c.cc.Invoke(ctx, DaemonService_SelectNetworks_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -173,9 +138,8 @@ func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetw
} }
func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SelectNetworksResponse) out := new(SelectNetworksResponse)
err := c.cc.Invoke(ctx, DaemonService_DeselectNetworks_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -183,9 +147,8 @@ func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNe
} }
func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) { func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequest, opts ...grpc.CallOption) (*ForwardingRulesResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ForwardingRulesResponse) out := new(ForwardingRulesResponse)
err := c.cc.Invoke(ctx, DaemonService_ForwardingRules_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/ForwardingRules", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -193,9 +156,8 @@ func (c *daemonServiceClient) ForwardingRules(ctx context.Context, in *EmptyRequ
} }
func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) { func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(DebugBundleResponse) out := new(DebugBundleResponse)
err := c.cc.Invoke(ctx, DaemonService_DebugBundle_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/DebugBundle", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -203,9 +165,8 @@ func (c *daemonServiceClient) DebugBundle(ctx context.Context, in *DebugBundleRe
} }
func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) { func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetLogLevelResponse) out := new(GetLogLevelResponse)
err := c.cc.Invoke(ctx, DaemonService_GetLogLevel_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetLogLevel", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -213,9 +174,8 @@ func (c *daemonServiceClient) GetLogLevel(ctx context.Context, in *GetLogLevelRe
} }
func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) { func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SetLogLevelResponse) out := new(SetLogLevelResponse)
err := c.cc.Invoke(ctx, DaemonService_SetLogLevel_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetLogLevel", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -223,9 +183,8 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe
} }
func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) { func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListStatesResponse) out := new(ListStatesResponse)
err := c.cc.Invoke(ctx, DaemonService_ListStates_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -233,9 +192,8 @@ func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequ
} }
func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) { func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(CleanStateResponse) out := new(CleanStateResponse)
err := c.cc.Invoke(ctx, DaemonService_CleanState_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -243,9 +201,8 @@ func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequ
} }
func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) { func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(DeleteStateResponse) out := new(DeleteStateResponse)
err := c.cc.Invoke(ctx, DaemonService_DeleteState_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -253,9 +210,8 @@ func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRe
} }
func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) { func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(SetNetworkMapPersistenceResponse) out := new(SetNetworkMapPersistenceResponse)
err := c.cc.Invoke(ctx, DaemonService_SetNetworkMapPersistence_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -263,22 +219,20 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *
} }
func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) { func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(TracePacketResponse) out := new(TracePacketResponse)
err := c.cc.Invoke(ctx, DaemonService_TracePacket_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return out, nil return out, nil
} }
func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (grpc.ServerStreamingClient[SystemEvent], error) { func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *SubscribeRequest, opts ...grpc.CallOption) (DaemonService_SubscribeEventsClient, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], "/daemon.DaemonService/SubscribeEvents", opts...)
stream, err := c.cc.NewStream(ctx, &DaemonService_ServiceDesc.Streams[0], DaemonService_SubscribeEvents_FullMethodName, cOpts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
x := &grpc.GenericClientStream[SubscribeRequest, SystemEvent]{ClientStream: stream} x := &daemonServiceSubscribeEventsClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil { if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err return nil, err
} }
@@ -288,13 +242,26 @@ func (c *daemonServiceClient) SubscribeEvents(ctx context.Context, in *Subscribe
return x, nil return x, nil
} }
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. type DaemonService_SubscribeEventsClient interface {
type DaemonService_SubscribeEventsClient = grpc.ServerStreamingClient[SystemEvent] Recv() (*SystemEvent, error)
grpc.ClientStream
}
type daemonServiceSubscribeEventsClient struct {
grpc.ClientStream
}
func (x *daemonServiceSubscribeEventsClient) Recv() (*SystemEvent, error) {
m := new(SystemEvent)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) { func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsRequest, opts ...grpc.CallOption) (*GetEventsResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetEventsResponse) out := new(GetEventsResponse)
err := c.cc.Invoke(ctx, DaemonService_GetEvents_FullMethodName, in, out, cOpts...) err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetEvents", in, out, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -303,7 +270,7 @@ func (c *daemonServiceClient) GetEvents(ctx context.Context, in *GetEventsReques
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility. // for forward compatibility
type DaemonServiceServer interface { type DaemonServiceServer interface {
// Login uses setup key to prepare configuration for the daemon. // Login uses setup key to prepare configuration for the daemon.
Login(context.Context, *LoginRequest) (*LoginResponse, error) Login(context.Context, *LoginRequest) (*LoginResponse, error)
@@ -340,17 +307,14 @@ type DaemonServiceServer interface {
// SetNetworkMapPersistence enables or disables network map persistence // SetNetworkMapPersistence enables or disables network map persistence
SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error)
TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error)
SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error
GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error)
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
// UnimplementedDaemonServiceServer must be embedded to have // UnimplementedDaemonServiceServer must be embedded to have forward compatible implementations.
// forward compatible implementations. type UnimplementedDaemonServiceServer struct {
// }
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedDaemonServiceServer struct{}
func (UnimplementedDaemonServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) { func (UnimplementedDaemonServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Login not implemented") return nil, status.Errorf(codes.Unimplemented, "method Login not implemented")
@@ -406,14 +370,13 @@ func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context
func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented") return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented")
} }
func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, grpc.ServerStreamingServer[SystemEvent]) error { func (UnimplementedDaemonServiceServer) SubscribeEvents(*SubscribeRequest, DaemonService_SubscribeEventsServer) error {
return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented") return status.Errorf(codes.Unimplemented, "method SubscribeEvents not implemented")
} }
func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) { func (UnimplementedDaemonServiceServer) GetEvents(context.Context, *GetEventsRequest) (*GetEventsResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetEvents not implemented")
} }
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
func (UnimplementedDaemonServiceServer) testEmbeddedByValue() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to DaemonServiceServer will // Use of this interface is not recommended, as added methods to DaemonServiceServer will
@@ -423,13 +386,6 @@ type UnsafeDaemonServiceServer interface {
} }
func RegisterDaemonServiceServer(s grpc.ServiceRegistrar, srv DaemonServiceServer) { func RegisterDaemonServiceServer(s grpc.ServiceRegistrar, srv DaemonServiceServer) {
// If the following call pancis, it indicates UnimplementedDaemonServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&DaemonService_ServiceDesc, srv) s.RegisterService(&DaemonService_ServiceDesc, srv)
} }
@@ -443,7 +399,7 @@ func _DaemonService_Login_Handler(srv interface{}, ctx context.Context, dec func
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_Login_FullMethodName, FullMethod: "/daemon.DaemonService/Login",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Login(ctx, req.(*LoginRequest)) return srv.(DaemonServiceServer).Login(ctx, req.(*LoginRequest))
@@ -461,7 +417,7 @@ func _DaemonService_WaitSSOLogin_Handler(srv interface{}, ctx context.Context, d
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_WaitSSOLogin_FullMethodName, FullMethod: "/daemon.DaemonService/WaitSSOLogin",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).WaitSSOLogin(ctx, req.(*WaitSSOLoginRequest)) return srv.(DaemonServiceServer).WaitSSOLogin(ctx, req.(*WaitSSOLoginRequest))
@@ -479,7 +435,7 @@ func _DaemonService_Up_Handler(srv interface{}, ctx context.Context, dec func(in
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_Up_FullMethodName, FullMethod: "/daemon.DaemonService/Up",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Up(ctx, req.(*UpRequest)) return srv.(DaemonServiceServer).Up(ctx, req.(*UpRequest))
@@ -497,7 +453,7 @@ func _DaemonService_Status_Handler(srv interface{}, ctx context.Context, dec fun
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_Status_FullMethodName, FullMethod: "/daemon.DaemonService/Status",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Status(ctx, req.(*StatusRequest)) return srv.(DaemonServiceServer).Status(ctx, req.(*StatusRequest))
@@ -515,7 +471,7 @@ func _DaemonService_Down_Handler(srv interface{}, ctx context.Context, dec func(
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_Down_FullMethodName, FullMethod: "/daemon.DaemonService/Down",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).Down(ctx, req.(*DownRequest)) return srv.(DaemonServiceServer).Down(ctx, req.(*DownRequest))
@@ -533,7 +489,7 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_GetConfig_FullMethodName, FullMethod: "/daemon.DaemonService/GetConfig",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetConfig(ctx, req.(*GetConfigRequest)) return srv.(DaemonServiceServer).GetConfig(ctx, req.(*GetConfigRequest))
@@ -551,7 +507,7 @@ func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, d
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_ListNetworks_FullMethodName, FullMethod: "/daemon.DaemonService/ListNetworks",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest))
@@ -569,7 +525,7 @@ func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context,
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_SelectNetworks_FullMethodName, FullMethod: "/daemon.DaemonService/SelectNetworks",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest))
@@ -587,7 +543,7 @@ func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Contex
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_DeselectNetworks_FullMethodName, FullMethod: "/daemon.DaemonService/DeselectNetworks",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest))
@@ -605,7 +561,7 @@ func _DaemonService_ForwardingRules_Handler(srv interface{}, ctx context.Context
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_ForwardingRules_FullMethodName, FullMethod: "/daemon.DaemonService/ForwardingRules",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest)) return srv.(DaemonServiceServer).ForwardingRules(ctx, req.(*EmptyRequest))
@@ -623,7 +579,7 @@ func _DaemonService_DebugBundle_Handler(srv interface{}, ctx context.Context, de
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_DebugBundle_FullMethodName, FullMethod: "/daemon.DaemonService/DebugBundle",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest)) return srv.(DaemonServiceServer).DebugBundle(ctx, req.(*DebugBundleRequest))
@@ -641,7 +597,7 @@ func _DaemonService_GetLogLevel_Handler(srv interface{}, ctx context.Context, de
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_GetLogLevel_FullMethodName, FullMethod: "/daemon.DaemonService/GetLogLevel",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest)) return srv.(DaemonServiceServer).GetLogLevel(ctx, req.(*GetLogLevelRequest))
@@ -659,7 +615,7 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_SetLogLevel_FullMethodName, FullMethod: "/daemon.DaemonService/SetLogLevel",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest)) return srv.(DaemonServiceServer).SetLogLevel(ctx, req.(*SetLogLevelRequest))
@@ -677,7 +633,7 @@ func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_ListStates_FullMethodName, FullMethod: "/daemon.DaemonService/ListStates",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest)) return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest))
@@ -695,7 +651,7 @@ func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_CleanState_FullMethodName, FullMethod: "/daemon.DaemonService/CleanState",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest)) return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest))
@@ -713,7 +669,7 @@ func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, de
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_DeleteState_FullMethodName, FullMethod: "/daemon.DaemonService/DeleteState",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest)) return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest))
@@ -731,7 +687,7 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_SetNetworkMapPersistence_FullMethodName, FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest)) return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest))
@@ -749,7 +705,7 @@ func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, de
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_TracePacket_FullMethodName, FullMethod: "/daemon.DaemonService/TracePacket",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest)) return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest))
@@ -762,11 +718,21 @@ func _DaemonService_SubscribeEvents_Handler(srv interface{}, stream grpc.ServerS
if err := stream.RecvMsg(m); err != nil { if err := stream.RecvMsg(m); err != nil {
return err return err
} }
return srv.(DaemonServiceServer).SubscribeEvents(m, &grpc.GenericServerStream[SubscribeRequest, SystemEvent]{ServerStream: stream}) return srv.(DaemonServiceServer).SubscribeEvents(m, &daemonServiceSubscribeEventsServer{stream})
} }
// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. type DaemonService_SubscribeEventsServer interface {
type DaemonService_SubscribeEventsServer = grpc.ServerStreamingServer[SystemEvent] Send(*SystemEvent) error
grpc.ServerStream
}
type daemonServiceSubscribeEventsServer struct {
grpc.ServerStream
}
func (x *daemonServiceSubscribeEventsServer) Send(m *SystemEvent) error {
return x.ServerStream.SendMsg(m)
}
func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetEventsRequest) in := new(GetEventsRequest)
@@ -778,7 +744,7 @@ func _DaemonService_GetEvents_Handler(srv interface{}, ctx context.Context, dec
} }
info := &grpc.UnaryServerInfo{ info := &grpc.UnaryServerInfo{
Server: srv, Server: srv,
FullMethod: DaemonService_GetEvents_FullMethodName, FullMethod: "/daemon.DaemonService/GetEvents",
} }
handler := func(ctx context.Context, req interface{}) (interface{}, error) { handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest)) return srv.(DaemonServiceServer).GetEvents(ctx, req.(*GetEventsRequest))

Some files were not shown because too many files have changed in this diff Show More