Compare commits

...

27 Commits

Author SHA1 Message Date
Viktor Liu
0ef476b014 [client] Fix state dump panic (#3519) 2025-03-16 15:13:04 +01:00
Zoltan Papp
6f82e96d6a [client] Set info logs (#3504)
collect and log connection stats per peer every 10 minutes
2025-03-14 22:34:41 +01:00
Viktor Liu
a2faae5d62 [client] Fix anonymized addresses documentation (#3505) 2025-03-14 11:38:16 +01:00
Zoltan Papp
4a3cbcd38a Nil check on route manager (#3486) 2025-03-13 00:04:00 +01:00
Misha Bragin
c2980bc8cf Update link to kubernetes operator (#3489) 2025-03-12 21:18:19 +01:00
Pascal Fischer
67ae871ce4 [management] return empty array instead of null on networks endpoints (#3480) 2025-03-11 00:20:54 +01:00
Maycon Santos
39ff5e833a [misc] Update slack invite link (#3479) 2025-03-11 00:12:11 +01:00
Zoltan Papp
cd9eff5331 Increase the timeout to 50 sec (#3481) 2025-03-10 18:23:47 +01:00
Viktor Liu
80ceb80197 [client] Ignore candidates that are part of the the wireguard subnet (#3472) 2025-03-10 13:59:21 +01:00
Zoltan Papp
636a0e2475 [client] Fix engine restart (#3435)
- Refactor the network monitoring to handle one event and it after return
- In the engine restart cancel the upper layer context and the responsibility of the engine stop will be the upper layer
- Before triggering a restart, the engine checks whether the state is already down. This helps avoid unnecessary delayed network restart events.
2025-03-10 13:32:12 +01:00
Viktor Liu
e66e329bf6 [client] Add option to autostart netbird ui in the Windows installer (#3469) 2025-03-10 13:19:17 +01:00
Zoltan Papp
aaa23beeec [client] Prevent to block channel writing (#3474)
The "runningChan" provides feedback to the UI or any client about whether the service is up and running. If the client exits earlier than when the service successfully starts, then this channel causes a block.

- Added timeout for reading the channel to ensure we don't cause blocks for too long for the caller
- Modified channel writing operations to be non-blocking
2025-03-10 13:17:09 +01:00
Zoltan Papp
6bef474e9e [client] Prevent panic in case of double close call (#3475)
Prevent panic in case of double close call
2025-03-10 13:16:28 +01:00
Maycon Santos
81040ff80a [docs] Update typo (#3477) 2025-03-10 11:52:36 +01:00
Viktor Liu
c73481aee4 [client] Enable windows stderr logs by default (#3476) 2025-03-10 11:30:49 +01:00
Viktor Liu
fc1da94520 [client, management] Add port forwarding (#3275)
Add initial support to ingress ports on the client code.

- new types where added
- new protocol messages and controller
2025-03-09 16:06:43 +01:00
Muzammil
ae6b61301c Muz/netbird dashboards (#3458)
* added all 3 dashboards

* update readme
2025-03-07 16:13:11 +01:00
Philippe Vaucher
a444e551b3 [misc] Traefik config improvements (#3346)
* Remove deprecated docker-compose version

* Prettify docker-compose files

* Backports missing logging entries

* Fix signal port

* Add missing relay configuration

* Serve management over 33073 to avoid confusion
2025-03-07 16:10:11 +01:00
Zoltan Papp
53b9a2002f Print out the goroutine id (#3433)
The TXT logger prints out the actual go routine ID

This feature depends on 'loggoroutine' build tag

```go build -tags loggoroutine```
2025-03-07 14:06:47 +01:00
Zoltan Papp
4b76d93cec [client] Fix TURN-Relay switch (#3456)
- When a peer is connected with TURN and a Relay connection is established, do not force switching to Relay. Keep using TURN until disconnection.

-In the proxy preparation phase, the Bind Proxy does not set the remote conn as a fake address for Bind. When running the Work() function, the proper proxy instance updates the conn inside the Bind.
2025-03-07 12:00:25 +01:00
Viktor Liu
062d1ec76f [misc] Update bug-issue-report.md template (#3449) 2025-03-06 01:10:37 +01:00
Viktor Liu
c111675dd8 [client] Handle large DNS packets in dns route resolution (#3441) 2025-03-05 18:57:17 +01:00
hakansa
60ffe0dc87 [client] UI Refactor Icon Paths (#3420)
[client] UI Refactor Icon Paths (#3420)
2025-03-04 18:29:29 +03:00
Viktor Liu
bcc5824980 [client] Close userspace firewall properly (#3426) 2025-03-04 11:19:42 +01:00
robertgro
af5796de1c [client] Add Netbird GitHub link to the client ui about sub menu (#3372) 2025-03-03 17:32:50 +01:00
Philippe Vaucher
9d604b7e66 [client Fix env var typo (#3415) 2025-03-03 17:22:51 +01:00
Bethuel Mmbaga
82c12cc8ae [management] Handle transaction error on peer deletion (#3387)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-25 19:57:04 +00:00
225 changed files with 12137 additions and 1966 deletions

View File

@@ -31,14 +31,22 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**NetBird status -dA output:** **Is any other VPN software installed?**
If applicable, add the `netbird status -dA' command output. If yes, which one?
**Do you face any (non-mobile) client issues?** **Debug output**
Please provide the file created by `netbird debug for 1m -AS`. To help us resolve the problem, please attach the following debug output
We advise reviewing the anonymized files for any remaining PII.
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
@@ -47,3 +55,10 @@ If applicable, add screenshots to help explain your problem.
**Additional context** **Additional context**
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -258,7 +258,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -325,8 +325,8 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -392,7 +392,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -461,7 +461,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:

View File

@@ -71,7 +71,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@@ -150,7 +150,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4

View File

@@ -53,9 +53,9 @@ nfpms:
scripts: scripts:
postinstall: "release_files/ui-post-install.sh" postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird.png - src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
@@ -72,9 +72,9 @@ nfpms:
scripts: scripts:
postinstall: "release_files/ui-post-install.sh" postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird.png - src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br> <br>
@@ -29,13 +29,13 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
<br/> <br/>
</strong> </strong>
<br> <br>
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github"> <a href="https://github.com/netbirdio/kubernetes-operator">
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts New: NetBird Kubernetes Operator
</a> </a>
</p> </p>

View File

@@ -26,7 +26,7 @@ type Anonymizer struct {
} }
func DefaultAddresses() (netip.Addr, netip.Addr) { func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100:: // 198.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
} }

View File

@@ -0,0 +1,98 @@
package cmd
import (
"fmt"
"sort"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var forwardingRulesCmd = &cobra.Command{
Use: "forwarding",
Short: "List forwarding rules",
Long: `Commands to list forwarding rules.`,
}
var forwardingRulesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List forwarding rules",
Example: " netbird forwarding list",
Long: "Commands to list forwarding rules.",
RunE: listForwardingRules,
}
func listForwardingRules(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.GetRules()) == 0 {
cmd.Println("No forwarding rules available.")
return nil
}
printForwardingRules(cmd, resp.GetRules())
return nil
}
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
// Sort rules by translated address
sort.Slice(rules, func(i, j int) bool {
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
}
if rules[i].GetProtocol() != rules[j].GetProtocol() {
return rules[i].GetProtocol() < rules[j].GetProtocol()
}
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
})
var lastIP string
for _, rule := range rules {
dPort := portToString(rule.GetDestinationPort())
tPort := portToString(rule.GetTranslatedPort())
if lastIP != rule.GetTranslatedAddress() {
lastIP = rule.GetTranslatedAddress()
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
}
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
}
}
func getFirstPort(portInfo *proto.PortInfo) int {
switch v := portInfo.PortSelection.(type) {
case *proto.PortInfo_Port:
return int(v.Port)
case *proto.PortInfo_Range_:
return int(v.Range.GetStart())
default:
return 0
}
}
func portToString(translatedPort *proto.PortInfo) string {
switch v := translatedPort.PortSelection.(type) {
case *proto.PortInfo_Port:
return fmt.Sprintf("%d", v.Port)
case *proto.PortInfo_Range_:
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
default:
return "No port specified"
}
}

View File

@@ -145,6 +145,7 @@ func init() {
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd) rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
@@ -153,6 +154,8 @@ func init() {
networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
debugCmd.AddCommand(debugBundleCmd) debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd) debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd) logCmd.AddCommand(logLevelCmd)

View File

@@ -10,6 +10,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -89,7 +90,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan error, 1) run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run); err != nil { if err := client.Run(run); err != nil {
run <- err clientErr <- err
} }
}() }()
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err()) return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
} }
return startCtx.Err() return startCtx.Err()
case err := <-run: case err := <-clientErr:
if err != nil { return fmt.Errorf("startup: %w", err)
if stopErr := client.Stop(); stopErr != nil { case <-run:
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
}
return fmt.Errorf("startup: %w", err)
}
} }
c.connect = client c.connect = client

View File

@@ -4,12 +4,13 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice

View File

@@ -30,10 +30,8 @@ type entry struct {
} }
type aclManager struct { type aclManager struct {
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string
entries aclEntries entries aclEntries
optionalEntries map[string][]entry optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
@@ -41,12 +39,10 @@ type aclManager struct {
stateManager *statemanager.Manager stateManager *statemanager.Manager
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
m := &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
routingFwChainName: routingFwChainName,
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
@@ -314,9 +310,12 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
} }
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return nil, fmt.Errorf("create router: %w", err) return nil, fmt.Errorf("create router: %w", err)
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) m.aclMgr, err = newAclManager(iptablesClient, wgIface)
if err != nil { if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
@@ -166,7 +166,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -226,6 +226,22 @@ func (m *Manager) DisableRouting() error {
return nil return nil
} }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -100,14 +100,14 @@ func TestIptablesManager(t *testing.T) {
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists") require.NoError(t, err, "failed check chain exists")
if ok { if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules) require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
} }
}) })
} }
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -136,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -166,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -204,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -16,6 +16,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -23,22 +24,36 @@ import (
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle" tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING" chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE" chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre" jumpManglePre = "jump-mangle-pre"
jumpNat = "jump-nat" jumpNatPre = "jump-nat-pre"
matchSet = "--match-set" jumpNatPost = "jump-nat-post"
matchSet = "--match-set"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
) )
type ruleInfo struct {
chain string
table string
rule []string
}
type routeFilteringRuleParams struct { type routeFilteringRuleParams struct {
Sources []netip.Prefix Sources []netip.Prefix
Destination netip.Prefix Destination netip.Prefix
@@ -62,6 +77,7 @@ type router struct {
legacyManagement bool legacyManagement bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
} }
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
@@ -69,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -139,9 +156,9 @@ func (r *router) AddRouteFiltering(
var err error var err error
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
// after the established rule // after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...) err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
} else { } else {
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...) err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
} }
if err != nil { if err != nil {
@@ -156,12 +173,12 @@ func (r *router) AddRouteFiltering(
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID() ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule) setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err) return fmt.Errorf("delete route rule: %v", err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
@@ -212,6 +229,10 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -238,6 +259,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)
} }
@@ -264,7 +289,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
} }
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
@@ -277,7 +302,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
@@ -305,7 +330,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue continue
} }
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else { } else {
delete(r.rules, k) delete(r.rules, k)
@@ -343,9 +368,11 @@ func (r *router) cleanUpDefaultForwardRules() error {
chain string chain string
table string table string
}{ }{
{chainRTFWD, tableFilter}, {chainRTFWDIN, tableFilter},
{chainRTNAT, tableNat}, {chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} { } {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
@@ -365,16 +392,22 @@ func (r *router) createContainers() error {
chain string chain string
table string table string
}{ }{
{chainRTFWD, tableFilter}, {chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} { } {
if err := r.createAndSetupChain(chainInfo.chain); err != nil { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
} }
if err := r.insertEstablishedRule(chainRTFWD); err != nil { if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
return fmt.Errorf("insert established rule: %w", err) return fmt.Errorf("insert established rule: %w", err)
} }
@@ -415,27 +448,6 @@ func (r *router) addPostroutingRules() error {
return nil return nil
} }
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
}
func (r *router) insertEstablishedRule(chain string) error { func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished() establishedRule := getConntrackEstablished()
@@ -454,28 +466,43 @@ func (r *router) addJumpRules() error {
// Jump to NAT chain // Jump to NAT chain
natRule := []string{"-j", chainRTNAT} natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err) return fmt.Errorf("add nat postrouting jump rule: %v", err)
} }
r.rules[jumpNat] = natRule r.rules[jumpNatPost] = natRule
// Jump to prerouting chain // Jump to mangle prerouting chain
preRule := []string{"-j", chainRTPRE} preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil { if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err) return fmt.Errorf("add mangle prerouting jump rule: %v", err)
} }
r.rules[jumpPre] = preRule r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRDR}
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %v", err)
}
r.rules[jumpNatPre] = rdrRule
return nil return nil
} }
func (r *router) cleanJumpRules() error { func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNat, jumpPre} { for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
table := tableNat var table, chain string
chain := chainPOSTROUTING switch ruleKey {
if ruleKey == jumpPre { case jumpNatPost:
table = tableNat
chain = chainPOSTROUTING
case jumpManglePre:
table = tableMangle table = tableMangle
chain = chainPREROUTING chain = chainPREROUTING
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
} }
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
@@ -520,6 +547,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
r.rules[ruleKey] = rule r.rules[ruleKey] = rule
r.updateState()
return nil return nil
} }
@@ -535,6 +564,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
log.Debugf("marking rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
r.updateState()
return nil return nil
} }
@@ -564,6 +594,137 @@ func (r *router) updateState() {
} }
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[string]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleKey+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFWDOUT,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleKey+dnatSuffix)
}
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleKey+snatSuffix)
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string var rule []string

View File

@@ -39,12 +39,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
}() }()
// Now 5 rules: // Now 5 rules:
// 1. established rule in forward chain // 1. established rule forward in
// 2. jump rule to NAT chain // 2. estbalished rule forward out
// 3. jump rule to PRE chain // 3. jump rule to POST nat chain
// 4. static outbound masquerade rule // 4. jump rule to PRE mangle chain
// 5. static return masquerade rule // 5. jump rule to PRE nat chain
require.Len(t, manager.rules, 5, "should have created rules map") // 6. static outbound masquerade rule
// 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -332,14 +334,14 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map // Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()] rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map") assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule // Log the internal rule
t.Logf("Internal rule: %v", rule) t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables // Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
assert.NoError(t, err, "Failed to check rule existence") assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables") assert.True(t, exists, "Rule not found in iptables")

View File

@@ -12,6 +12,6 @@ type Rule struct {
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) ID() string {
return r.ruleID return r.ruleID
} }

View File

@@ -4,21 +4,20 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() device.WGAddress { func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }
@@ -62,7 +61,7 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore ipt.aclMgr.ipsetStore = s.ACLIPsetStore
} }
if err := ipt.Reset(nil); err != nil { if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err) return fmt.Errorf("reset iptables manager: %w", err)
} }

View File

@@ -26,8 +26,8 @@ const (
// Each firewall type for different OS can use different type // Each firewall type for different OS can use different type
// of the properties to hold data of the created rule // of the properties to hold data of the created rule
type Rule interface { type Rule interface {
// GetRuleID returns the rule id // ID returns the rule id
GetRuleID() string ID() string
} }
// RuleDirection is the traffic direction which a rule is applied // RuleDirection is the traffic direction which a rule is applied
@@ -94,8 +94,8 @@ type Manager interface {
// SetLegacyManagement sets the legacy management mode // SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Close closes the firewall manager
Reset(stateManager *statemanager.Manager) error Close(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error
@@ -105,6 +105,12 @@ type Manager interface {
EnableRouting() error EnableRouting() error
DisableRouting() error DisableRouting() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -0,0 +1,27 @@
package manager
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port
TranslatedAddress netip.Addr
TranslatedPort Port
}
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
}
func (r ForwardRule) String() string {
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
}

View File

@@ -1,30 +1,12 @@
package manager package manager
import ( import (
"fmt"
"strconv" "strconv"
) )
// Protocol is the protocol of the port
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
// ProtocolUnknown unknown protocol
ProtocolUnknown Protocol = "unknown"
)
// Port of the address for firewall rule // Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct { type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port // IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool IsRange bool
@@ -33,6 +15,25 @@ type Port struct {
Values []uint16 Values []uint16
} }
func NewPort(ports ...int) (*Port, error) {
if len(ports) == 0 {
return nil, fmt.Errorf("no port provided")
}
ports16 := make([]uint16, len(ports))
for i, port := range ports {
if port < 1 || port > 65535 {
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
}
ports16[i] = uint16(port)
}
return &Port{
IsRange: len(ports) > 1,
Values: ports16,
}, nil
}
// String interface implementation // String interface implementation
func (p *Port) String() string { func (p *Port) String() string {
var ports string var ports string

View File

@@ -0,0 +1,19 @@
package manager
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
)

View File

@@ -127,7 +127,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
log.Errorf("failed to delete mangle rule: %v", err) log.Errorf("failed to delete mangle rule: %v", err)
} }
} }
delete(m.rules, r.GetRuleID()) delete(m.rules, r.ID())
return m.rConn.Flush() return m.rConn.Flush()
} }
@@ -141,7 +141,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
log.Errorf("failed to delete mangle rule: %v", err) log.Errorf("failed to delete mangle rule: %v", err)
} }
} }
delete(m.rules, r.GetRuleID()) delete(m.rules, r.ID())
return m.rConn.Flush() return m.rConn.Flush()
} }
@@ -176,7 +176,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return err return err
} }
delete(m.rules, r.GetRuleID()) delete(m.rules, r.ID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name) m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) { if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// We only need to record minimal interface state for potential recreation. // We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains // Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy // a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules. // cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{ if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
@@ -242,7 +242,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -342,6 +342,22 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush() return m.aclManager.Flush()
} }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains // established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion") require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Close(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
t.Cleanup(func() { t.Cleanup(func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state") require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset // Verify iptables output after reset

View File

@@ -14,23 +14,31 @@ import (
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
chainNameRoutingFw = "netbird-rt-fwd" tableNat = "nat"
chainNameRoutingNat = "netbird-rt-postrouting" chainNameNatPrerouting = "PREROUTING"
chainNameForward = "FORWARD" chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptForwardRuleOif = "frwacceptoif"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
) )
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
@@ -49,16 +57,18 @@ type router struct {
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool legacyManagement bool
} }
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{ r := &router{
conn: &nftables.Conn{}, conn: &nftables.Conn{},
workTable: workTable, workTable: workTable,
chains: make(map[string]*nftables.Chain), chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -98,7 +108,52 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller // clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear() r.ipsetCounter.Clear()
return r.removeAcceptForwardRules() var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -133,14 +188,22 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager // Chain is created by acl manager
// TODO: move creation to a common place // TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{ r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting, Name: chainNamePrerouting,
Table: r.workTable, Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
} }
// Add the single NAT rule that matches on mark // Add the single NAT rule that matches on mark
@@ -281,7 +344,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
ruleKey := rule.GetRuleID() ruleKey := rule.ID()
nftRule, exists := r.rules[ruleKey] nftRule, exists := r.rules[ruleKey]
if !exists { if !exists {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
@@ -410,6 +473,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -836,6 +903,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -896,6 +967,269 @@ func (r *router) refreshRulesMap() error {
return nil return nil
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := protoToInt(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleKey)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
dnatRule := &nftables.Rule{
Table: &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
},
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: r.filterTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleKey + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleKey+snatSuffix] = masqRule
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
delete(r.rules, ruleKey+dnatSuffix)
delete(r.rules, ruleKey+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR // generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32 var offset uint32
@@ -959,15 +1293,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port.IsRange && len(port.Values) == 2 { if port.IsRange && len(port.Values) == 2 {
// Handle port range // Handle port range
exprs = append(exprs, exprs = append(exprs,
&expr.Cmp{ &expr.Range{
Op: expr.CmpOpGte, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: binaryutil.BigEndian.PutUint16(port.Values[0]), FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
}, ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
}, },
) )
} else { } else {

View File

@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present // need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}) })
// Check if the rule is in the internal map // Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()] rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map") assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:") t.Log("Internal rule expressions:")
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule var nftRule *nftables.Rule
for _, rule := range rules { for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() { if string(rule.UserData) == ruleKey.ID() {
nftRule = rule nftRule = rule
break break
} }
@@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 { if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true payloadFound = true
} }
case *expr.Cmp: case *expr.Range:
if port.IsRange { if port.IsRange && len(port.Values) == 2 {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte { fromPort := binary.BigEndian.Uint16(ex.FromData)
toPort := binary.BigEndian.Uint16(ex.ToData)
if fromPort == port.Values[0] && toPort == port.Values[1] {
portMatchFound = true portMatchFound = true
} }
} else { }
case *expr.Cmp:
if !port.IsRange {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 { if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data) portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values { for _, p := range port.Values {
if uint16(p) == portValue { if p == portValue {
portMatchFound = true portMatchFound = true
break break
} }

View File

@@ -16,6 +16,6 @@ type Rule struct {
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) ID() string {
return r.ruleID return r.ruleID
} }

View File

@@ -3,21 +3,20 @@ package nftables
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() device.WGAddress { func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }
@@ -39,7 +38,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create nftables manager: %w", err) return fmt.Errorf("create nftables manager: %w", err)
} }
if err := nft.Reset(nil); err != nil { if err := nft.Close(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err) return fmt.Errorf("reset nftables manager: %w", err)
} }

View File

@@ -8,12 +8,11 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -22,17 +21,14 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
} }
if m.forwarder != nil { if m.forwarder != nil {
@@ -48,7 +44,7 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
} }
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager) return m.nativeFirewall.Close(stateManager)
} }
return nil return nil
} }

View File

@@ -9,7 +9,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -21,8 +20,8 @@ const (
firewallRuleName = "Netbird" firewallRuleName = "Netbird"
) )
// Reset firewall to the default state // Close closes the firewall manager
func (m *Manager) Reset(*statemanager.Manager) error { func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -31,17 +30,14 @@ func (m *Manager) Reset(*statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
} }
if m.forwarder != nil { if m.forwarder != nil {

View File

@@ -3,14 +3,14 @@ package common
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
Address() iface.WGAddress Address() wgaddr.Address
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
} }

View File

@@ -1,6 +1,7 @@
package conntrack package conntrack
import ( import (
"context"
"net" "net"
"sync" "sync"
"time" "time"
@@ -39,8 +40,8 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs ipPool *PreallocatedIPs
} }
@@ -50,16 +51,18 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger, logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
@@ -119,12 +122,14 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
conn.Sequence == seq conn.Sequence == seq
} }
func (t *ICMPTracker) cleanupRoutine() { func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -146,8 +151,7 @@ func (t *ICMPTracker) cleanup() {
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections { for _, conn := range t.connections {

View File

@@ -3,6 +3,7 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections // TODO: Send RST packets for invalid/timed-out connections
import ( import (
"context"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -85,23 +86,26 @@ type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
done chan struct{} tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
ipPool *PreallocatedIPs ipPool *PreallocatedIPs
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
timeout: timeout, timeout: timeout,
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
@@ -315,12 +319,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false return false
} }
func (t *TCPTracker) cleanupRoutine() { func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -355,8 +361,7 @@ func (t *TCPTracker) cleanup() {
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
// Clean up all remaining IPs // Clean up all remaining IPs
t.mutex.Lock() t.mutex.Lock()

View File

@@ -1,6 +1,7 @@
package conntrack package conntrack
import ( import (
"context"
"net" "net"
"sync" "sync"
"time" "time"
@@ -26,8 +27,8 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs ipPool *PreallocatedIPs
} }
@@ -37,16 +38,18 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
@@ -103,12 +106,14 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
// cleanupRoutine periodically removes stale connections // cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() { func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -131,8 +136,7 @@ func (t *UDPTracker) cleanup() {
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() { func (t *UDPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections { for _, conn := range t.connections {

View File

@@ -1,6 +1,7 @@
package conntrack package conntrack
import ( import (
"context"
"net" "net"
"testing" "testing"
"time" "time"
@@ -34,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) {
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker) assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done) assert.NotNil(t, tracker.tickerCancel)
}) })
} }
} }
@@ -154,18 +155,21 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval // Create tracker with custom cleanup interval
tracker := &UDPTracker{ tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}), tickerCancel: tickerCancel,
ipPool: NewPreallocatedIPs(), ipPool: NewPreallocatedIPs(),
logger: logger, logger: logger,
} }
// Start cleanup routine // Start cleanup routine
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
// Add some connections // Add some connections
connections := []struct { connections := []struct {

View File

@@ -6,19 +6,19 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
func TestLocalIPManager(t *testing.T) { func TestLocalIPManager(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setupAddr iface.WGAddress setupAddr wgaddr.Address
testIP net.IP testIP net.IP
expected bool expected bool
}{ }{
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
@@ -30,7 +30,7 @@ func TestLocalIPManager(t *testing.T) {
}, },
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
@@ -42,7 +42,7 @@ func TestLocalIPManager(t *testing.T) {
}, },
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
@@ -54,7 +54,7 @@ func TestLocalIPManager(t *testing.T) {
}, },
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
@@ -66,7 +66,7 @@ func TestLocalIPManager(t *testing.T) {
}, },
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
@@ -78,7 +78,7 @@ func TestLocalIPManager(t *testing.T) {
}, },
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("fe80::"), IP: net.ParseIP("fe80::"),
@@ -95,7 +95,7 @@ func TestLocalIPManager(t *testing.T) {
manager := newLocalIPManager() manager := newLocalIPManager()
mock := &IFaceMock{ mock := &IFaceMock{
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return tt.setupAddr return tt.setupAddr
}, },
} }

View File

@@ -24,8 +24,8 @@ type PeerRule struct {
udpHook func([]byte) bool udpHook func([]byte) bool
} }
// GetRuleID returns the rule id // ID returns the rule id
func (r *PeerRule) GetRuleID() string { func (r *PeerRule) ID() string {
return r.id return r.id
} }
@@ -39,7 +39,7 @@ type RouteRule struct {
action firewall.Action action firewall.Action
} }
// GetRuleID returns the rule id // ID returns the rule id
func (r *RouteRule) GetRuleID() string { func (r *RouteRule) ID() string {
return r.id return r.id
} }

View File

@@ -42,6 +42,8 @@ const (
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule type RuleSet map[string]PeerRule
@@ -437,7 +439,7 @@ func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
ruleID := rule.GetRuleID() ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
return r.id == ruleID return r.id == ruleID
}) })
@@ -478,6 +480,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.processOutgoingHooks(packetData) return m.processOutgoingHooks(packetData)

View File

@@ -160,7 +160,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -205,7 +205,7 @@ func BenchmarkStateScaling(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -253,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -452,7 +452,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
// Setup scenario // Setup scenario
@@ -579,7 +579,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -670,7 +670,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -789,7 +789,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -877,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{

View File

@@ -12,9 +12,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
@@ -26,8 +26,8 @@ func TestPeerACLFiltering(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: localIP, IP: localIP,
Network: wgNet, Network: wgNet,
} }
@@ -39,7 +39,7 @@ func TestPeerACLFiltering(t *testing.T) {
require.NotNil(t, manager) require.NotNil(t, manager)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet manager.wgNetwork = wgNet
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: localIP, IP: localIP,
Network: wgNet, Network: wgNet,
} }
@@ -310,7 +310,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
require.False(tb, manager.nativeRouter) require.False(tb, manager.nativeRouter)
tb.Cleanup(func() { tb.Cleanup(func() {
require.NoError(tb, manager.Reset(nil)) require.NoError(tb, manager.Close(nil))
}) })
return manager return manager

View File

@@ -16,15 +16,15 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
} }
@@ -50,9 +50,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface) return i.SetFilterFunc(iface)
} }
func (i *IFaceMock) Address() iface.WGAddress { func (i *IFaceMock) Address() wgaddr.Address {
if i.AddressFunc == nil { if i.AddressFunc == nil {
return iface.WGAddress{} return wgaddr.Address{}
} }
return i.AddressFunc() return i.AddressFunc()
} }
@@ -135,7 +135,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok { if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -149,7 +149,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok { if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -254,7 +254,7 @@ func TestManagerReset(t *testing.T) {
return return
} }
err = m.Reset(nil) err = m.Close(nil)
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@@ -268,8 +268,8 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"), IP: net.ParseIP("100.10.0.0"),
@@ -333,7 +333,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if err = m.Reset(nil); err != nil { if err = m.Close(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@@ -352,7 +352,7 @@ func TestRemovePacketHook(t *testing.T) {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
// Add a UDP packet hook // Add a UDP packet hook
@@ -403,7 +403,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
@@ -484,7 +484,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Close(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -530,7 +530,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, },
} }
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
// Set up packet parameters // Set up packet parameters

View File

@@ -5,7 +5,6 @@ import (
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
"strings"
"sync" "sync"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
@@ -14,6 +13,8 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -52,9 +53,10 @@ type ICEBind struct {
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -64,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
address: address,
} }
rc := receiverCreator{ rc := receiverCreator{
@@ -108,35 +111,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil return s.udpMux, nil
} }
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) { func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock() b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock() b.endpointsMu.Unlock()
return fakeUDPAddr, nil
} }
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) { func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock() b.endpointsMu.Lock()
defer b.endpointsMu.Unlock() defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
delete(b.endpoints, fakeIP)
} }
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
@@ -161,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: conn,
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address,
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -275,21 +261,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
} }
} }
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message { func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message) return msgsPool.Get().(*[]ipv6.Message)
} }

View File

@@ -17,6 +17,8 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// FilterFn is a function that filters out candidates based on the address. // FilterFn is a function that filters out candidates based on the address.
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration XORMappedAddrCacheTTL time.Duration
Net transport.Net Net transport.Net
FilterFn FilterFn FilterFn FilterFn
WGAddress wgaddr.Address
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
filterFn: params.FilterFn, filterFn: params.FilterFn,
address: params.WGAddress,
} }
// embed UDPMux // embed UDPMux
@@ -118,6 +122,7 @@ type udpConn struct {
filterFn FilterFn filterFn FilterFn
// TODO: reset cache on route changes // TODO: reset cache on route changes
addrCache sync.Map addrCache sync.Map
address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil { if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err) log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else { } else {

View File

@@ -9,13 +9,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error UpdateAddr(address wgaddr.Address) error
WgAddress() WGAddress WgAddress() wgaddr.Address
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct { type WGTunDevice struct {
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -31,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement // todo implement
return nil return nil
} }
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *WGTunDevice) WgAddress() WGAddress { func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -29,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -14,11 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
iceBind *bind.ICEBind iceBind *bind.ICEBind
@@ -30,7 +31,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) UpdateAddr(addr WGAddress) error { func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement // todo implement
return nil return nil
} }

View File

@@ -14,12 +14,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
) )
type TunKernelDevice struct { type TunKernelDevice struct {
name string name string
address WGAddress address wgaddr.Address
wgPort int wgPort int
key string key string
mtu int mtu int
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn filterFn bind.FilterFn
} }
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
bindParams := bind.UniversalUDPMuxParams{ bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock, UDPConn: rawSock,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil return t.udpMux, nil
} }
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
return closErr return closErr
} }
func (t *TunKernelDevice) WgAddress() WGAddress { func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,12 +13,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
return nil return nil
} }
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
return nil return nil
} }
func (t *TunNetstackDevice) WgAddress() WGAddress { func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -12,11 +12,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type USPDevice struct { type USPDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -28,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
return &USPDevice{ return &USPDevice{
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *USPDevice) UpdateAddr(address WGAddress) error { func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
return nil return nil
} }
func (t *USPDevice) WgAddress() WGAddress { func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,13 +13,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -32,7 +33,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
} }
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd" "github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgLink struct { type wgLink struct {
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address WGAddress) error { func (l *wgLink) assignAddr(address wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name) link, err := freebsd.LinkByName(l.name)
if err != nil { if err != nil {
return fmt.Errorf("link by name: %w", err) return fmt.Errorf("link by name: %w", err)

View File

@@ -8,6 +8,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgLink struct { type wgLink struct {
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address WGAddress) error { func (l *wgLink) assignAddr(address wgaddr.Address) error {
//delete existing addresses //delete existing addresses
list, err := netlink.AddrList(l, 0) list, err := netlink.AddrList(l, 0)
if err != nil { if err != nil {

View File

@@ -7,13 +7,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error UpdateAddr(address wgaddr.Address) error
WgAddress() WGAddress WgAddress() wgaddr.Address
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -28,8 +29,6 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault WgInterfaceDefault = configurer.WgInterfaceDefault
) )
type WGAddress = device.WGAddress
type wgProxyFactory interface { type wgProxyFactory interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Free() error Free() error
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
} }
// Address returns the interface address // Address returns the interface address
func (w *WGIface) Address() device.WGAddress { func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress() return w.tun.WgAddress()
} }
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
addr, err := device.ParseWGAddress(newAddr) addr, err := wgaddr.ParseWGAddress(newAddr)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,17 +3,18 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,

View File

@@ -6,17 +6,18 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -5,17 +5,18 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),

View File

@@ -8,12 +8,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{} wgIFace := &WGIface{}
if netstack.IsEnabled() { if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return wgIFace, nil return wgIFace, nil
} }
if device.ModuleTunIsLoaded() { if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -4,16 +4,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -55,7 +55,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
if err != nil { if err != nil {
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err) log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
} }
if skipProxy { if skipProxy {
return nsTunDev, tunNet, nil return nsTunDev, tunNet, nil

View File

@@ -1,29 +1,29 @@
package device package wgaddr
import ( import (
"fmt" "fmt"
"net" "net"
) )
// WGAddress WireGuard parsed address // Address WireGuard parsed address
type WGAddress struct { type Address struct {
IP net.IP IP net.IP
Network *net.IPNet Network *net.IPNet
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (WGAddress, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) ip, network, err := net.ParseCIDR(address)
if err != nil { if err != nil {
return WGAddress{}, err return Address{}, err
} }
return WGAddress{ return Address{
IP: ip, IP: ip,
Network: network, Network: network,
}, nil }, nil
} }
func (addr WGAddress) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -16,13 +17,13 @@ import (
type ProxyBind struct { type ProxyBind struct {
Bind *bind.ICEBind Bind *bind.ICEBind
wgAddr *net.UDPAddr fakeNetIP *netip.AddrPort
wgEndpoint *bind.Endpoint wgBindEndpoint *bind.Endpoint
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
@@ -33,20 +34,24 @@ type ProxyBind struct {
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration. // WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) fakeNetIP, err := fakeAddress(nbAddr)
if err != nil { if err != nil {
return err return err
} }
p.wgAddr = addr p.fakeNetIP = fakeNetIP
p.wgEndpoint = addrToEndpoint(addr) p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
return err return nil
} }
func (p *ProxyBind) EndpointAddr() *net.UDPAddr { func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return p.wgAddr return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
} }
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
@@ -54,6 +59,8 @@ func (p *ProxyBind) Work() {
return return
} }
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.pausedMu.Lock() p.pausedMu.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock() p.pausedMu.Unlock()
@@ -93,7 +100,7 @@ func (p *ProxyBind) close() error {
p.cancel() p.cancel()
p.Bind.RemoveEndpoint(p.wgAddr) p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr return rErr
@@ -126,7 +133,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
} }
msg := bind.RecvMessage{ msg := bind.RecvMessage{
Endpoint: p.wgEndpoint, Endpoint: p.wgBindEndpoint,
Buffer: buf[:n], Buffer: buf[:n],
} }
p.Bind.RecvChan <- msg p.Bind.RecvChan <- msg
@@ -134,8 +141,19 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
} }
} }
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { // fakeAddress returns a fake address that is used to as an identifier for the peer.
ip, _ := netip.AddrFromSlice(addr.IP.To4()) // The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
return &bind.Endpoint{AddrPort: addrPort} octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
} }

View File

@@ -6,8 +6,8 @@
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network" !define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
!define INSTALLER_NAME "netbird-installer.exe" !define INSTALLER_NAME "netbird-installer.exe"
!define MAIN_APP_EXE "Netbird" !define MAIN_APP_EXE "Netbird"
!define ICON "ui\\netbird.ico" !define ICON "ui\\assets\\netbird.ico"
!define BANNER "ui\\banner.bmp" !define BANNER "ui\\build\\banner.bmp"
!define LICENSE_DATA "..\\LICENSE" !define LICENSE_DATA "..\\LICENSE"
!define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}" !define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}"
@@ -22,6 +22,8 @@
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}" !define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}" !define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -68,6 +70,9 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
@@ -80,8 +85,36 @@ ShowInstDetails Show
!insertmacro MUI_LANGUAGE "English" !insertmacro MUI_LANGUAGE "English"
; Variables for autostart option
Var AutostartCheckbox
Var AutostartEnabled
###################################################################### ######################################################################
; Function to create the autostart options page
Function AutostartPage
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show
FunctionEnd
; Function to handle leaving the autostart page
Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
EnVar::SetHKLM EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR" EnVar::AddValueEx "path" "$INSTDIR"
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client # kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
Sleep 3000 Sleep 3000

View File

@@ -12,7 +12,7 @@ import (
type RuleID string type RuleID string
func (r RuleID) GetRuleID() string { func (r RuleID) ID() string {
return string(r) return string(r)
} }

View File

@@ -245,7 +245,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
return "", fmt.Errorf("add route rule: %w", err) return "", fmt.Errorf("add route rule: %w", err)
} }
return id.RuleID(addedRule.GetRuleID()), nil return id.RuleID(addedRule.ID()), nil
} }
func (d *DefaultManager) protoRuleToFirewallRule( func (d *DefaultManager) protoRuleToFirewallRule(
@@ -515,7 +515,7 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
for _, rules := range newRulePairs { for _, rules := range newRulePairs {
for _, rule := range rules { for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil { if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
} }
} }
} }

View File

@@ -8,7 +8,7 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -45,7 +45,7 @@ func TestDefaultManager(t *testing.T) {
} }
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
@@ -58,7 +58,7 @@ func TestDefaultManager(t *testing.T) {
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset(nil) _ = fw.Close(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
@@ -74,7 +74,7 @@ func TestDefaultManager(t *testing.T) {
t.Run("add extra rules", func(t *testing.T) { t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{} existedPairs := map[string]struct{}{}
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
existedPairs[id.GetRuleID()] = struct{}{} existedPairs[id.ID()] = struct{}{}
} }
// remove first rule // remove first rule
@@ -100,7 +100,7 @@ func TestDefaultManager(t *testing.T) {
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
if _, ok := existedPairs[id.GetRuleID()]; ok { if _, ok := existedPairs[id.ID()]; ok {
previousCount++ previousCount++
} }
} }
@@ -339,7 +339,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
} }
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
@@ -352,7 +352,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset(nil) _ = fw.Close(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)

View File

@@ -10,8 +10,8 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// MockIFaceMapper is a mock of IFaceMapper interface. // MockIFaceMapper is a mock of IFaceMapper interface.
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
} }
// Address mocks base method. // Address mocks base method.
func (m *MockIFaceMapper) Address() iface.WGAddress { func (m *MockIFaceMapper) Address() wgaddr.Address {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Address") ret := m.ctrl.Call(m, "Address")
ret0, _ := ret[0].(iface.WGAddress) ret0, _ := ret[0].(wgaddr.Address)
return ret0 return ret0
} }

View File

@@ -61,7 +61,7 @@ func NewConnectClient(
} }
// Run with main logic. // Run with main logic.
func (c *ConnectClient) Run(runningChan chan error) error { func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan) return c.run(MobileDependency{}, runningChan)
} }
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil) return c.run(mobileDependency, nil)
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error { func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
rec := c.statusRecorder rec := c.statusRecorder
@@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
defer c.statusRecorder.ClientStop() defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
if c.isContextCancelled() { if c.ctx.Err() != nil {
return nil return nil
} }
@@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil && runningChanOpen { if runningChan != nil {
runningChan <- nil select {
close(runningChan) case runningChan <- struct{}{}:
runningChanOpen = false default:
}
} }
<-engineCtx.Done() <-engineCtx.Done()
@@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error {
return nil return nil
} }
func (c *ConnectClient) isContextCancelled() bool {
select {
case <-c.ctx.Done():
return true
default:
return false
}
}
// SetNetworkMapPersistence enables or disables network map persistence. // SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved // When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored // through the Engine's getLatestNetworkMap method. When disabled, any stored

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -37,9 +38,9 @@ func (w *mocWGIface) Name() string {
panic("implement me") panic("implement me")
} }
func (w *mocWGIface) Address() iface.WGAddress { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24") ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{ return wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
} }
@@ -1015,7 +1016,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
mh.AssertExpectations(t) mh.AssertExpectations(t)
} }
// Reset mocks // Close mocks
if mh, ok := tc.expectedHandler.(*MockHandler); ok { if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.ExpectedCalls = nil mh.ExpectedCalls = nil
mh.Calls = nil mh.Calls = nil

View File

@@ -5,15 +5,15 @@ package dns
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter

View File

@@ -1,15 +1,15 @@
package dns package dns
import ( import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice

View File

@@ -25,7 +25,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
@@ -33,6 +33,7 @@ import (
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
@@ -169,10 +170,11 @@ type Engine struct {
statusRecorder *peer.Status statusRecorder *peer.Status
firewall manager.Manager firewall firewallManager.Manager
routeManager routemanager.Manager routeManager routemanager.Manager
acl acl.Manager acl acl.Manager
dnsForwardMgr *dnsfwd.Manager dnsForwardMgr *dnsfwd.Manager
ingressGatewayMgr *ingressgw.Manager
dnsServer dns.Server dnsServer dns.Server
@@ -266,6 +268,13 @@ func (e *Engine) Stop() error {
// stop/restore DNS first so dbus and friends don't complain because of a missing interface // stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer() e.stopDNSServer()
if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil {
log.Warnf("failed to cleanup forward rules: %v", err)
}
e.ingressGatewayMgr = nil
}
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop(e.stateManager) e.routeManager.Stop(e.stateManager)
} }
@@ -469,15 +478,15 @@ func (e *Engine) initFirewall() error {
} }
rosenpassPort := e.rpManager.GetAddress().Port rosenpassPort := e.rpManager.GetAddress().Port
port := manager.Port{Values: []uint16{uint16(rosenpassPort)}} port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
// this rule is static and will be torn down on engine down by the firewall manager // this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering( if _, err := e.firewall.AddPeerFiltering(
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
manager.ProtocolUDP, firewallManager.ProtocolUDP,
nil, nil,
&port, &port,
manager.ActionAccept, firewallManager.ActionAccept,
"", "",
"", "",
); err != nil { ); err != nil {
@@ -505,10 +514,10 @@ func (e *Engine) blockLanAccess() {
if _, err := e.firewall.AddRouteFiltering( if _, err := e.firewall.AddRouteFiltering(
[]netip.Prefix{v4}, []netip.Prefix{v4},
network, network,
manager.ProtocolALL, firewallManager.ProtocolALL,
nil, nil,
nil, nil,
manager.ActionDrop, firewallManager.ActionDrop,
); err != nil { ); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err)) merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
} }
@@ -912,6 +921,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
// Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers()) e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -1362,7 +1376,7 @@ func (e *Engine) close() {
} }
if e.firewall != nil { if e.firewall != nil {
err := e.firewall.Reset(e.stateManager) err := e.firewall.Close(e.stateManager)
if err != nil { if err != nil {
log.Warnf("failed to reset firewall: %s", err) log.Warnf("failed to reset firewall: %s", err)
} }
@@ -1482,7 +1496,7 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
} }
// GetFirewallManager returns the firewall manager // GetFirewallManager returns the firewall manager
func (e *Engine) GetFirewallManager() manager.Manager { func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall return e.firewall
} }
@@ -1575,16 +1589,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns) return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
} }
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() { func (e *Engine) restartEngine() {
log.Info("restarting engine") e.syncMsgMux.Lock()
CtxGetState(e.ctx).Set(StatusConnecting) defer e.syncMsgMux.Unlock()
if err := e.Stop(); err != nil { if e.ctx.Err() != nil {
log.Errorf("Failed to stop engine: %v", err) return
} }
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated") log.Infof("cancelling client context, engine will be recreated")
e.clientCancel() e.clientCancel()
} }
@@ -1596,34 +1613,17 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New() e.networkMonitor = networkmonitor.New()
go func() { go func() {
var mu sync.Mutex if err := e.networkMonitor.Listen(e.ctx); err != nil {
var debounceTimer *time.Timer if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
// Start the network monitor with a callback, Start will block until the monitor is stopped, return
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
} }
log.Errorf("network monitor error: %v", err)
// Set a new timer to debounce rapid network changes return
debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
} }
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
}() }()
} }
@@ -1770,6 +1770,74 @@ func (e *Engine) Address() (netip.Addr, error) {
return ip.Unmap(), nil return ip.Unmap(), nil
} }
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules")
return nil
}
if len(rules) == 0 {
if e.ingressGatewayMgr == nil {
return nil
}
err := e.ingressGatewayMgr.Close()
e.ingressGatewayMgr = nil
e.statusRecorder.SetIngressGwMgr(nil)
return err
}
if e.ingressGatewayMgr == nil {
mgr := ingressgw.NewManager(e.firewall)
e.ingressGatewayMgr = mgr
e.statusRecorder.SetIngressGwMgr(mgr)
}
var merr *multierror.Error
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
for _, rule := range rules {
proto, err := convertToFirewallProtocol(rule.GetProtocol())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
continue
}
dstPortInfo, err := convertPortInfo(rule.GetDestinationPort())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid destination port '%v': %w", rule.GetDestinationPort(), err))
continue
}
translateIP, err := convertToIP(rule.GetTranslatedAddress())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert translated address '%s': %w", rule.GetTranslatedAddress(), err))
continue
}
translatePort, err := convertPortInfo(rule.GetTranslatedPort())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid translate port '%v': %w", rule.GetTranslatedPort(), err))
continue
}
forwardRule := firewallManager.ForwardRule{
Protocol: proto,
DestinationPort: *dstPortInfo,
TranslatedAddress: translateIP,
TranslatedPort: *translatePort,
}
forwardingRules = append(forwardingRules, forwardRule)
}
log.Infof("updating forwarding rules: %d", len(forwardingRules))
if err := e.ingressGatewayMgr.Update(forwardingRules); err != nil {
log.Errorf("failed to update forwarding rules: %v", err)
}
return nberrors.FormatErrorOrNil(merr)
}
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks { for _, check := range checks {

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -44,6 +45,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -74,7 +76,7 @@ type MockWGIface struct {
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool IsUserspaceBindFunc func() bool
NameFunc func() string NameFunc func() string
AddressFunc func() device.WGAddress AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
@@ -113,7 +115,7 @@ func (m *MockWGIface) Name() string {
return m.NameFunc() return m.NameFunc()
} }
func (m *MockWGIface) Address() device.WGAddress { func (m *MockWGIface) Address() wgaddr.Address {
return m.AddressFunc() return m.AddressFunc()
} }
@@ -363,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -1433,7 +1435,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock())
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -20,7 +21,7 @@ type wgIfaceBase interface {
CreateOnAndroid(routeRange []string, ip string, domains []string) error CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool IsUserspaceBind() bool
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error

View File

@@ -0,0 +1,107 @@
package ingressgw
import (
"fmt"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
type DNATFirewall interface {
AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error)
DeleteDNATRule(rule firewall.Rule) error
}
type RulePair struct {
firewall.ForwardRule
firewall.Rule
}
type Manager struct {
dnatFirewall DNATFirewall
rules map[string]RulePair // keys is the ID of the ForwardRule
rulesMu sync.Mutex
}
func NewManager(dnatFirewall DNATFirewall) *Manager {
return &Manager{
dnatFirewall: dnatFirewall,
rules: make(map[string]RulePair),
}
}
func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
var mErr *multierror.Error
toDelete := make(map[string]RulePair, len(h.rules))
for id, r := range h.rules {
toDelete[id] = r
}
// Process new/updated rules
for _, fwdRule := range forwardRules {
id := fwdRule.ID()
if _, ok := h.rules[id]; ok {
delete(toDelete, id)
continue
}
rule, err := h.dnatFirewall.AddDNATRule(fwdRule)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
continue
}
log.Infof("forward rule has been added '%s'", fwdRule)
h.rules[id] = RulePair{
ForwardRule: fwdRule,
Rule: rule,
}
}
// Remove deleted rules
for id, rulePair := range toDelete {
if err := h.dnatFirewall.DeleteDNATRule(rulePair.Rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rulePair.ForwardRule.String(), err))
}
log.Infof("forward rule has been deleted '%s'", rulePair.ForwardRule)
delete(h.rules, id)
}
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) Close() error {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
log.Infof("clean up all (%d) forward rules", len(h.rules))
var mErr *multierror.Error
for _, rule := range h.rules {
if err := h.dnatFirewall.DeleteDNATRule(rule.Rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rule, err))
}
}
h.rules = make(map[string]RulePair)
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) Rules() []firewall.ForwardRule {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
rules := make([]firewall.ForwardRule, 0, len(h.rules))
for _, rulePair := range h.rules {
rules = append(rules, rulePair.ForwardRule)
}
return rules
}

View File

@@ -0,0 +1,281 @@
package ingressgw
import (
"fmt"
"net/netip"
"testing"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var (
_ firewall.Rule = (*MocFwRule)(nil)
_ DNATFirewall = &MockDNATFirewall{}
)
type MocFwRule struct {
id string
}
func (m *MocFwRule) ID() string {
return string(m.id)
}
type MockDNATFirewall struct {
throwError bool
}
func (m *MockDNATFirewall) AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error) {
if m.throwError {
return nil, fmt.Errorf("moc error")
}
fwRule := &MocFwRule{
id: fwdRule.ID(),
}
return fwRule, nil
}
func (m *MockDNATFirewall) DeleteDNATRule(rule firewall.Rule) error {
if m.throwError {
return fmt.Errorf("moc error")
}
return nil
}
func (m *MockDNATFirewall) forceToThrowErrors() {
m.throwError = true
}
func TestManager_AddRule(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
updates := []firewall.ForwardRule{
{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
},
{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}}
if err := mgr.Update(updates); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != len(updates) {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_UpdateRule(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleUDP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 1 {
t.Errorf("unexpected rules count: %d", len(rules))
}
if rules[0].TranslatedAddress.String() != ruleUDP.TranslatedAddress.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].TranslatedPort.String() != ruleUDP.TranslatedPort.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].DestinationPort.String() != ruleUDP.DestinationPort.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].Protocol != ruleUDP.Protocol {
t.Errorf("unexpected rule: %v", rules[0])
}
}
func TestManager_ExtendRules(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 2 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_UnderlingError(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
fw.forceToThrowErrors()
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err == nil {
t.Errorf("expected error")
}
rules := mgr.Rules()
if len(rules) != 1 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_Cleanup(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_DeleteBrokenRule(t *testing.T) {
fw := &MockDNATFirewall{}
// force to throw errors when Add DNAT Rule
fw.forceToThrowErrors()
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err == nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
// simulate that to remove a broken rule
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Close(); err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestManager_Close(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Close(); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}

View File

@@ -0,0 +1,58 @@
package internal
import (
"errors"
"fmt"
"net"
"net/netip"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewallManager.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP:
return firewallManager.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP:
return firewallManager.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL:
return firewallManager.ProtocolALL, nil
default:
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
if portInfo == nil {
return nil, errors.New("portInfo cannot be nil")
}
if portInfo.GetPort() != 0 {
return firewallManager.NewPort(int(portInfo.GetPort()))
}
if portInfo.GetRange() != nil {
return firewallManager.NewPort(int(portInfo.GetRange().Start), int(portInfo.GetRange().End))
}
return nil, fmt.Errorf("invalid portInfo: %v", portInfo)
}
func convertToIP(rawIP []byte) (netip.Addr, error) {
if rawIP == nil {
return netip.Addr{}, errors.New("input bytes cannot be nil")
}
if len(rawIP) != net.IPv4len && len(rawIP) != net.IPv6len {
return netip.Addr{}, fmt.Errorf("invalid IP length: %d", len(rawIP))
}
if len(rawIP) == net.IPv4len {
return netip.AddrFrom4([4]byte(rawIP)), nil
}
return netip.AddrFrom16([16]byte(rawIP)), nil
}

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil { if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("failed to open routing socket: %v", err)
@@ -28,18 +28,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
} }
}() }()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
default: default:
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
@@ -76,11 +68,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
switch msg.Type { switch msg.Type {
case unix.RTM_ADD: case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback() return nil
case unix.RTM_DELETE: case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback() return nil
} }
} }
} }

View File

@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil { if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available") return errors.New("no interfaces available")
} }
@@ -31,8 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
// handle route changes // handle route changes
case route := <-routeChan: case route := <-routeChan:
// default route and main table // default route and main table
@@ -43,12 +42,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
// triggered on added/replaced routes // triggered on added/replaced routes
case syscall.RTM_NEWROUTE: case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil return nil
case syscall.RTM_DELROUTE: case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) { if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil return nil
} }
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
routeMonitor, err := systemops.NewRouteMonitor(ctx) routeMonitor, err := systemops.NewRouteMonitor(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create route monitor: %w", err) return fmt.Errorf("failed to create route monitor: %w", err)
@@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
case route := <-routeMonitor.RouteUpdates(): case route := <-routeMonitor.RouteUpdates():
if route.Destination.Bits() != 0 { if route.Destination.Bits() != 0 {
continue continue
} }
if routeChanged(route, nexthopv4, nexthopv6, callback) { if routeChanged(route, nexthopv4, nexthopv6) {
break return nil
} }
} }
} }
} }
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool { func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
intf := "<nil>" intf := "<nil>"
if route.Interface != nil { if route.Interface != nil {
intf = route.Interface.Name intf = route.Interface.Name
@@ -51,18 +51,15 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne
case systemops.RouteModified: case systemops.RouteModified:
// TODO: get routing table to figure out if our route is affected for modified routes // TODO: get routing table to figure out if our route is affected for modified routes
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
case systemops.RouteAdded: case systemops.RouteAdded:
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
} }
case systemops.RouteDeleted: case systemops.RouteDeleted:
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
} }
} }

View File

@@ -1,12 +1,27 @@
//go:build !ios && !android
package networkmonitor package networkmonitor
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/netip"
"runtime/debug"
"sync" "sync"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
var ErrStopped = errors.New("monitor has been stopped") const (
debounceTime = 2 * time.Second
)
var checkChangeFn = checkChange
// NetworkMonitor watches for changes in network configuration. // NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct { type NetworkMonitor struct {
@@ -19,3 +34,99 @@ type NetworkMonitor struct {
func New() *NetworkMonitor { func New() *NetworkMonitor {
return &NetworkMonitor{} return &NetworkMonitor{}
} }
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()
return errors.New("network monitor already started")
}
ctx, nw.cancel = context.WithCancel(ctx)
defer nw.cancel()
nw.wg.Add(1)
nw.mu.Unlock()
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
// debounce changes
timer := time.NewTimer(0)
timer.Stop()
for {
select {
case <-event:
timer.Reset(debounceTime)
case <-timer.C:
return nil
case <-ctx.Done():
timer.Stop()
return ctx.Err()
}
}
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel == nil {
return
}
nw.cancel()
nw.wg.Wait()
}
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
for {
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
close(event)
return
}
// prevent blocking
select {
case event <- struct{}{}:
default:
}
}
}

View File

@@ -1,82 +0,0 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if ctx.Err() != nil {
return ctx.Err()
}
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx)
nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}
return nil
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil {
nw.cancel()
nw.wg.Wait()
}
}

View File

@@ -2,10 +2,21 @@
package networkmonitor package networkmonitor
import "context" import (
"context"
"fmt"
)
func (nw *NetworkMonitor) Start(context.Context, func()) error { type NetworkMonitor struct {
return nil }
// New creates a new network monitor.
func New() *NetworkMonitor {
return &NetworkMonitor{}
}
func (nw *NetworkMonitor) Listen(_ context.Context) error {
return fmt.Errorf("network monitor not supported on mobile platforms")
} }
func (nw *NetworkMonitor) Stop() { func (nw *NetworkMonitor) Stop() {

View File

@@ -0,0 +1,99 @@
package networkmonitor
import (
"context"
"errors"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
type MocMultiEvent struct {
counter int
}
func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
if m.counter == 0 {
<-ctx.Done()
return ctx.Err()
}
time.Sleep(1 * time.Second)
m.counter--
return nil
}
func TestNetworkMonitor_Close(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
<-ctx.Done()
return ctx.Err()
}
nw := New()
var resErr error
done := make(chan struct{})
go func() {
resErr = nw.Listen(context.Background())
close(done)
}()
time.Sleep(1 * time.Second) // wait for the goroutine to start
nw.Stop()
<-done
if !errors.Is(resErr, context.Canceled) {
t.Errorf("unexpected error: %v", resErr)
}
}
func TestNetworkMonitor_Event(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
timeout, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-timeout.Done():
return nil
}
}
nw := New()
defer nw.Stop()
var resErr error
done := make(chan struct{})
go func() {
resErr = nw.Listen(context.Background())
close(done)
}()
<-done
if !errors.Is(resErr, nil) {
t.Errorf("unexpected error: %v", nil)
}
}
func TestNetworkMonitor_MultiEvent(t *testing.T) {
eventsRepeated := 3
me := &MocMultiEvent{counter: eventsRepeated}
checkChangeFn = me.checkChange
nw := New()
defer nw.Stop()
done := make(chan struct{})
started := time.Now()
go func() {
if resErr := nw.Listen(context.Background()); resErr != nil {
t.Errorf("unexpected error: %v", resErr)
}
close(done)
}()
<-done
expectedResponseTime := time.Duration(eventsRepeated)*time.Second + debounceTime
if time.Since(started) < expectedResponseTime {
t.Errorf("unexpected duration: %v", time.Since(started))
}
}

View File

@@ -114,6 +114,9 @@ type Conn struct {
guard *guard.Guard guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup semaphore *semaphoregroup.SemaphoreGroup
// debug purpose
dumpState *stateDump
} }
// NewConn creates a new not opened Conn to the remote peer. // NewConn creates a new not opened Conn to the remote peer.
@@ -137,10 +140,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusRelay: NewAtomicConnStatus(), statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(),
semaphore: semaphore, semaphore: semaphore,
dumpState: newStateDump(connLog),
} }
ctrl := isController(config) ctrl := isController(config)
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager) conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally) workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
@@ -160,6 +164,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
go conn.handshaker.Listen() go conn.handshaker.Listen()
go conn.dumpState.Start(ctx)
return conn, nil return conn, nil
} }
@@ -193,6 +198,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
defer conn.semaphore.Done(conn.ctx) defer conn.semaphore.Done(conn.ctx)
conn.waitInitialRandomSleepTime(ctx) conn.waitInitialRandomSleepTime(ctx)
conn.dumpState.SendOffer()
err := conn.handshaker.sendOffer() err := conn.handshaker.sendOffer()
if err != nil { if err != nil {
conn.log.Errorf("failed to send initial offer: %v", err) conn.log.Errorf("failed to send initial offer: %v", err)
@@ -251,12 +257,14 @@ func (conn *Conn) Close() {
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise // OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
// doesn't block, discards the message if connection wasn't ready // doesn't block, discards the message if connection wasn't ready
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
conn.log.Debugf("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay) conn.dumpState.RemoteAnswer()
conn.log.Infof("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteAnswer(answer) return conn.handshaker.OnRemoteAnswer(answer)
} }
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate()
conn.workerICE.OnRemoteCandidate(candidate, haRoutes) conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
} }
@@ -278,7 +286,8 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
} }
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) conn.dumpState.RemoteOffer()
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
return conn.handshaker.OnRemoteOffer(offer) return conn.handshaker.OnRemoteOffer(offer)
} }
@@ -322,6 +331,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
} }
conn.log.Infof("set ICE to active connection") conn.log.Infof("set ICE to active connection")
conn.dumpState.P2PConnected()
var ( var (
ep *net.UDPAddr ep *net.UDPAddr
@@ -329,6 +339,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
err error err error
) )
if iceConnInfo.RelayedOnLocal { if iceConnInfo.RelayedOnLocal {
conn.dumpState.NewLocalProxy()
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn) wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
if err != nil { if err != nil {
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
@@ -390,6 +401,7 @@ func (conn *Conn) onICEStateDisconnected() {
// switch back to relay connection // switch back to relay connection
if conn.isReadyToUpgrade() { if conn.isReadyToUpgrade() {
conn.log.Infof("ICE disconnected, set Relay to active connection") conn.log.Infof("ICE disconnected, set Relay to active connection")
conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work() conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
@@ -432,6 +444,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
return return
} }
conn.dumpState.RelayConnected()
conn.log.Debugf("Relay connection has been established, setup the WireGuard") conn.log.Debugf("Relay connection has been established, setup the WireGuard")
wgProxy, err := conn.newProxy(rci.relayedConn) wgProxy, err := conn.newProxy(rci.relayedConn)
@@ -439,11 +452,12 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
return return
} }
conn.dumpState.NewLocalProxy()
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.iceP2PIsActive() { if conn.isICEActive() {
conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected) conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
@@ -481,10 +495,10 @@ func (conn *Conn) onRelayDisconnected() {
return return
} }
conn.log.Debugf("relay connection is disconnected") conn.log.Infof("relay connection is disconnected")
if conn.currentConnPriority == connPriorityRelay { if conn.currentConnPriority == connPriorityRelay {
conn.log.Debugf("clean up WireGuard config") conn.log.Infof("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil { if err := conn.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err) conn.log.Errorf("failed to remove wg endpoint: %v", err)
} }
@@ -516,7 +530,8 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
for { for {
select { select {
case <-conn.guard.Reconnect: case <-conn.guard.Reconnect:
conn.log.Debugf("send offer to peer") conn.log.Infof("send offer to peer")
conn.dumpState.SendOffer()
if err := conn.handshaker.SendOffer(); err != nil { if err := conn.handshaker.SendOffer(); err != nil {
conn.log.Errorf("failed to send offer: %v", err) conn.log.Errorf("failed to send offer: %v", err)
} }
@@ -711,8 +726,8 @@ func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
} }
func (conn *Conn) iceP2PIsActive() bool { func (conn *Conn) isICEActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
} }
func (conn *Conn) removeWgPeer() error { func (conn *Conn) removeWgPeer() error {

View File

@@ -76,19 +76,19 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
func (h *Handshaker) Listen() { func (h *Handshaker) Listen() {
for { for {
h.log.Debugf("wait for remote offer confirmation") h.log.Info("wait for remote offer confirmation")
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation() remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
if err != nil { if err != nil {
var connectionClosedError *ConnectionClosedError var connectionClosedError *ConnectionClosedError
if errors.As(err, &connectionClosedError) { if errors.As(err, &connectionClosedError) {
h.log.Tracef("stop handshaker") h.log.Info("exit from handshaker")
return return
} }
h.log.Errorf("failed to received remote offer confirmation: %s", err) h.log.Errorf("failed to received remote offer confirmation: %s", err)
continue continue
} }
h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort) h.log.Infof("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
for _, listener := range h.onNewOfferListeners { for _, listener := range h.onNewOfferListeners {
go listener(remoteOfferAnswer) go listener(remoteOfferAnswer)
} }
@@ -108,7 +108,7 @@ func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
case h.remoteOffersCh <- offer: case h.remoteOffersCh <- offer:
return true return true
default: default:
h.log.Debugf("OnRemoteOffer skipping message because is not ready") h.log.Warnf("OnRemoteOffer skipping message because is not ready")
// connection might not be ready yet to receive so we ignore the message // connection might not be ready yet to receive so we ignore the message
return false return false
} }
@@ -131,8 +131,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
// received confirmation from the remote peer -> ready to proceed // received confirmation from the remote peer -> ready to proceed
err := h.sendAnswer() if err := h.sendAnswer(); err != nil {
if err != nil {
return nil, err return nil, err
} }
return &remoteOfferAnswer, nil return &remoteOfferAnswer, nil
@@ -168,7 +167,7 @@ func (h *Handshaker) sendOffer() error {
} }
func (h *Handshaker) sendAnswer() error { func (h *Handshaker) sendAnswer() error {
h.log.Debugf("sending answer") h.log.Infof("sending answer")
uFrag, pwd := h.ice.GetLocalUserCredentials() uFrag, pwd := h.ice.GetLocalUserCredentials()
answer := OfferAnswer{ answer := OfferAnswer{

View File

@@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -16,4 +17,5 @@ type WGIface interface {
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Address() wgaddr.Address
} }

View File

@@ -0,0 +1,112 @@
package peer
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
type stateDump struct {
log *log.Entry
sentOffer int
remoteOffer int
remoteAnswer int
remoteCandidate int
p2pConnected int
switchToRelay int
wgCheckSuccess int
relayConnected int
localProxies int
mu sync.Mutex
}
func newStateDump(log *log.Entry) *stateDump {
return &stateDump{
log: log,
}
}
func (s *stateDump) Start(ctx context.Context) {
ticker := time.NewTicker(10 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.dumpState()
case <-ctx.Done():
return
}
}
}
func (s *stateDump) RemoteOffer() {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteOffer++
}
func (s *stateDump) RemoteCandidate() {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteCandidate++
}
func (s *stateDump) SendOffer() {
s.mu.Lock()
defer s.mu.Unlock()
s.sentOffer++
}
func (s *stateDump) dumpState() {
s.mu.Lock()
defer s.mu.Unlock()
s.log.Infof("Dump stat: SentOffer: %d, RemoteOffer: %d, RemoteAnswer: %d, RemoteCandidate: %d, P2PConnected: %d, SwitchToRelay: %d, WGCheckSuccess: %d, RelayConnected: %d, LocalProxies: %d",
s.sentOffer, s.remoteOffer, s.remoteAnswer, s.remoteCandidate, s.p2pConnected, s.switchToRelay, s.wgCheckSuccess, s.relayConnected, s.localProxies)
}
func (s *stateDump) RemoteAnswer() {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteAnswer++
}
func (s *stateDump) P2PConnected() {
s.mu.Lock()
defer s.mu.Unlock()
s.p2pConnected++
}
func (s *stateDump) SwitchToRelay() {
s.mu.Lock()
defer s.mu.Unlock()
s.switchToRelay++
}
func (s *stateDump) WGcheckSuccess() {
s.mu.Lock()
defer s.mu.Unlock()
s.wgCheckSuccess++
}
func (s *stateDump) RelayConnected() {
s.mu.Lock()
defer s.mu.Unlock()
s.relayConnected++
}
func (s *stateDump) NewLocalProxy() {
s.mu.Lock()
defer s.mu.Unlock()
s.localProxies++
}

View File

@@ -14,7 +14,9 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
@@ -132,13 +134,14 @@ type NSGroupState struct {
// FullStatus contains the full state held by the Status instance // FullStatus contains the full state held by the Status instance
type FullStatus struct { type FullStatus struct {
Peers []State Peers []State
ManagementState ManagementState ManagementState ManagementState
SignalState SignalState SignalState SignalState
LocalPeerState LocalPeerState LocalPeerState LocalPeerState
RosenpassState RosenpassState RosenpassState RosenpassState
Relays []relay.ProbeResult Relays []relay.ProbeResult
NSGroupStates []NSGroupState NSGroupStates []NSGroupState
NumOfForwardingRules int
} }
// Status holds a state of peers, signal, management connections and relays // Status holds a state of peers, signal, management connections and relays
@@ -171,6 +174,8 @@ type Status struct {
eventMux sync.RWMutex eventMux sync.RWMutex
eventStreams map[string]chan *proto.SystemEvent eventStreams map[string]chan *proto.SystemEvent
eventQueue *EventQueue eventQueue *EventQueue
ingressGwMgr *ingressgw.Manager
} }
// NewRecorder returns a new Status instance // NewRecorder returns a new Status instance
@@ -193,6 +198,12 @@ func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
d.relayMgr = manager d.relayMgr = manager
} }
func (d *Status) SetIngressGwMgr(ingressGwMgr *ingressgw.Manager) {
d.mux.Lock()
defer d.mux.Unlock()
d.ingressGwMgr = ingressGwMgr
}
// ReplaceOfflinePeers replaces // ReplaceOfflinePeers replaces
func (d *Status) ReplaceOfflinePeers(replacement []State) { func (d *Status) ReplaceOfflinePeers(replacement []State) {
d.mux.Lock() d.mux.Lock()
@@ -235,6 +246,18 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
return state, nil return state, nil
} }
func (d *Status) PeerByIP(ip string) (string, bool) {
d.mux.Lock()
defer d.mux.Unlock()
for _, state := range d.peers {
if state.IP == ip {
return state.FQDN, true
}
}
return "", false
}
// RemovePeer removes peer from Daemon status map // RemovePeer removes peer from Daemon status map
func (d *Status) RemovePeer(peerPubKey string) error { func (d *Status) RemovePeer(peerPubKey string) error {
d.mux.Lock() d.mux.Lock()
@@ -734,6 +757,16 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return append(relayStates, relayState) return append(relayStates, relayState)
} }
func (d *Status) ForwardingRules() []firewall.ForwardRule {
d.mux.Lock()
defer d.mux.Unlock()
if d.ingressGwMgr == nil {
return nil
}
return d.ingressGwMgr.Rules()
}
func (d *Status) GetDNSStates() []NSGroupState { func (d *Status) GetDNSStates() []NSGroupState {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -751,11 +784,12 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
// GetFullStatus gets full status // GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus { func (d *Status) GetFullStatus() FullStatus {
fullStatus := FullStatus{ fullStatus := FullStatus{
ManagementState: d.GetManagementState(), ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(), SignalState: d.GetSignalState(),
Relays: d.GetRelayStates(), Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(), RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(), NSGroupStates: d.GetDNSStates(),
NumOfForwardingRules: len(d.ForwardingRules()),
} }
d.mux.Lock() d.mux.Lock()

View File

@@ -27,6 +27,7 @@ type WGWatcher struct {
log *log.Entry log *log.Entry
wgIfaceStater WGInterfaceStater wgIfaceStater WGInterfaceStater
peerKey string peerKey string
stateDump *stateDump
ctx context.Context ctx context.Context
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
@@ -34,11 +35,12 @@ type WGWatcher struct {
waitGroup sync.WaitGroup waitGroup sync.WaitGroup
} }
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher { func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
return &WGWatcher{ return &WGWatcher{
log: log, log: log,
wgIfaceStater: wgIfaceStater, wgIfaceStater: wgIfaceStater,
peerKey: peerKey, peerKey: peerKey,
stateDump: stateDump,
} }
} }
@@ -105,6 +107,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
resetTime := time.Until(handshake.Add(checkPeriod)) resetTime := time.Until(handshake.Add(checkPeriod))
timer.Reset(resetTime) timer.Reset(resetTime)
w.stateDump.WGcheckSuccess()
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime) w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
case <-ctx.Done(): case <-ctx.Done():

View File

@@ -43,7 +43,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
mlog := log.WithField("peer", "tet") mlog := log.WithField("peer", "tet")
mocWgIface := &MocWgIface{} mocWgIface := &MocWgIface{}
watcher := NewWGWatcher(mlog, mocWgIface, "") watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump(mlog))
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@@ -72,7 +72,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
mlog := log.WithField("peer", "tet") mlog := log.WithField("peer", "tet")
mocWgIface := &MocWgIface{} mocWgIface := &MocWgIface{}
watcher := NewWGWatcher(mlog, mocWgIface, "") watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump(mlog))
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()

View File

@@ -358,6 +358,12 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
} }
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
var routePrefixes []netip.Prefix var routePrefixes []netip.Prefix
for _, routes := range clientRoutes { for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil { if len(routes) > 0 && routes[0] != nil {
@@ -365,14 +371,8 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
} }
} }
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes { for _, prefix := range routePrefixes {
// default route is // default route is handled by route exclusion / ip rules
if prefix.Bits() == 0 { if prefix.Bits() == 0 {
continue continue
} }

View File

@@ -33,14 +33,14 @@ type WorkerRelay struct {
wgWatcher *WGWatcher wgWatcher *WGWatcher
} }
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService) *WorkerRelay { func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
r := &WorkerRelay{ r := &WorkerRelay{
log: log, log: log,
isController: ctrl, isController: ctrl,
config: config, config: config,
conn: conn, conn: conn,
relayManager: relayManager, relayManager: relayManager,
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key), wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
} }
return r return r
} }

View File

@@ -302,7 +302,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
// If the chosen route is the same as the current route, do nothing // If the chosen route is the same as the current route, do nothing
if c.currentChosen != nil && c.currentChosen.ID == newChosenID && if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
c.currentChosen.IsEqual(c.routes[newChosenID]) { c.currentChosen.Equal(c.routes[newChosenID]) {
return nil return nil
} }

View File

@@ -160,6 +160,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
client := &dns.Client{ client := &dns.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Net: "udp", Net: "udp",

View File

@@ -3,9 +3,9 @@ package iface
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgIfaceBase interface { type wgIfaceBase interface {
@@ -13,7 +13,7 @@ type wgIfaceBase interface {
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter

View File

@@ -0,0 +1,51 @@
package ipfwdstate
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// IPForwardingState is a struct that keeps track of the IP forwarding state.
// todo: read initial state of the IP forwarding from the system and reset the state based on it
type IPForwardingState struct {
enabledCounter int
}
func NewIPForwardingState() *IPForwardingState {
return &IPForwardingState{}
}
func (f *IPForwardingState) RequestForwarding() error {
if f.enabledCounter != 0 {
f.enabledCounter++
return nil
}
if err := systemops.EnableIPForwarding(); err != nil {
return fmt.Errorf("failed to enable IP forwarding with sysctl: %w", err)
}
f.enabledCounter = 1
log.Info("IP forwarding enabled")
return nil
}
func (f *IPForwardingState) ReleaseForwarding() error {
if f.enabledCounter == 0 {
return nil
}
if f.enabledCounter > 1 {
f.enabledCounter--
return nil
}
// if failed to disable IP forwarding we anyway decrement the counter
f.enabledCounter = 0
// todo call systemops.DisableIPForwarding()
return nil
}

View File

@@ -13,7 +13,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -41,7 +40,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
for routeID := range m.routes { for routeID := range m.routes {
update, found := routesMap[routeID] update, found := routesMap[routeID]
if !found || !update.IsEqual(m.routes[routeID]) { if !found || !update.Equal(m.routes[routeID]) {
serverRoutesToRemove = append(serverRoutesToRemove, routeID) serverRoutesToRemove = append(serverRoutesToRemove, routeID)
} }
} }
@@ -71,9 +70,6 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
} }
if len(m.routes) > 0 { if len(m.routes) > 0 {
if err := systemops.EnableIPForwarding(); err != nil {
return fmt.Errorf("enable ip forwarding: %w", err)
}
if err := m.firewall.EnableRouting(); err != nil { if err := m.firewall.EnableRouting(); err != nil {
return fmt.Errorf("enable routing: %w", err) return fmt.Errorf("enable routing: %w", err)
} }

Some files were not shown because too many files have changed in this diff Show More