mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-11 03:09:55 +00:00
Compare commits
27 Commits
feature/ba
...
v0.38.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ef476b014 | ||
|
|
6f82e96d6a | ||
|
|
a2faae5d62 | ||
|
|
4a3cbcd38a | ||
|
|
c2980bc8cf | ||
|
|
67ae871ce4 | ||
|
|
39ff5e833a | ||
|
|
cd9eff5331 | ||
|
|
80ceb80197 | ||
|
|
636a0e2475 | ||
|
|
e66e329bf6 | ||
|
|
aaa23beeec | ||
|
|
6bef474e9e | ||
|
|
81040ff80a | ||
|
|
c73481aee4 | ||
|
|
fc1da94520 | ||
|
|
ae6b61301c | ||
|
|
a444e551b3 | ||
|
|
53b9a2002f | ||
|
|
4b76d93cec | ||
|
|
062d1ec76f | ||
|
|
c111675dd8 | ||
|
|
60ffe0dc87 | ||
|
|
bcc5824980 | ||
|
|
af5796de1c | ||
|
|
9d604b7e66 | ||
|
|
82c12cc8ae |
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -31,14 +31,22 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
|||||||
|
|
||||||
`netbird version`
|
`netbird version`
|
||||||
|
|
||||||
**NetBird status -dA output:**
|
**Is any other VPN software installed?**
|
||||||
|
|
||||||
If applicable, add the `netbird status -dA' command output.
|
If yes, which one?
|
||||||
|
|
||||||
**Do you face any (non-mobile) client issues?**
|
**Debug output**
|
||||||
|
|
||||||
Please provide the file created by `netbird debug for 1m -AS`.
|
To help us resolve the problem, please attach the following debug output
|
||||||
We advise reviewing the anonymized files for any remaining PII.
|
|
||||||
|
netbird status -dA
|
||||||
|
|
||||||
|
As well as the file created by
|
||||||
|
|
||||||
|
netbird debug for 1m -AS
|
||||||
|
|
||||||
|
|
||||||
|
We advise reviewing the anonymized output for any remaining personal information.
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
|
|
||||||
@@ -47,3 +55,10 @@ If applicable, add screenshots to help explain your problem.
|
|||||||
**Additional context**
|
**Additional context**
|
||||||
|
|
||||||
Add any other context about the problem here.
|
Add any other context about the problem here.
|
||||||
|
|
||||||
|
**Have you tried these troubleshooting steps?**
|
||||||
|
- [ ] Checked for newer NetBird versions
|
||||||
|
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||||
|
- [ ] Restarted the NetBird client
|
||||||
|
- [ ] Disabled other VPN software
|
||||||
|
- [ ] Checked firewall settings
|
||||||
|
|||||||
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -258,7 +258,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ 'amd64' ]
|
||||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
@@ -325,8 +325,8 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ 'amd64' ]
|
||||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
@@ -392,7 +392,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ 'amd64' ]
|
||||||
store: [ 'sqlite', 'postgres' ]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
@@ -461,7 +461,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ 'amd64' ]
|
||||||
store: [ 'sqlite', 'postgres']
|
store: [ 'sqlite', 'postgres']
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
4
.github/workflows/release.yml
vendored
4
.github/workflows/release.yml
vendored
@@ -71,7 +71,7 @@ jobs:
|
|||||||
- name: Install goversioninfo
|
- name: Install goversioninfo
|
||||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||||
- name: Generate windows syso amd64
|
- name: Generate windows syso amd64
|
||||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
@@ -150,7 +150,7 @@ jobs:
|
|||||||
- name: Install goversioninfo
|
- name: Install goversioninfo
|
||||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||||
- name: Generate windows syso amd64
|
- name: Generate windows syso amd64
|
||||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||||
|
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
|
|||||||
@@ -53,9 +53,9 @@ nfpms:
|
|||||||
scripts:
|
scripts:
|
||||||
postinstall: "release_files/ui-post-install.sh"
|
postinstall: "release_files/ui-post-install.sh"
|
||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/build/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird.png
|
- src: client/ui/assets/netbird.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
@@ -72,9 +72,9 @@ nfpms:
|
|||||||
scripts:
|
scripts:
|
||||||
postinstall: "release_files/ui-post-install.sh"
|
postinstall: "release_files/ui-post-install.sh"
|
||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/build/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird.png
|
- src: client/ui/assets/netbird.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
@@ -29,13 +29,13 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||||
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
|
New: NetBird Kubernetes Operator
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ type Anonymizer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||||
// 192.51.100.0, 100::
|
// 198.51.100.0, 100::
|
||||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
98
client/cmd/forwarding_rules.go
Normal file
98
client/cmd/forwarding_rules.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var forwardingRulesCmd = &cobra.Command{
|
||||||
|
Use: "forwarding",
|
||||||
|
Short: "List forwarding rules",
|
||||||
|
Long: `Commands to list forwarding rules.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var forwardingRulesListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List forwarding rules",
|
||||||
|
Example: " netbird forwarding list",
|
||||||
|
Long: "Commands to list forwarding rules.",
|
||||||
|
RunE: listForwardingRules,
|
||||||
|
}
|
||||||
|
|
||||||
|
func listForwardingRules(cmd *cobra.Command, _ []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.GetRules()) == 0 {
|
||||||
|
cmd.Println("No forwarding rules available.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
printForwardingRules(cmd, resp.GetRules())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
|
||||||
|
cmd.Println("Available forwarding rules:")
|
||||||
|
|
||||||
|
// Sort rules by translated address
|
||||||
|
sort.Slice(rules, func(i, j int) bool {
|
||||||
|
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
|
||||||
|
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
|
||||||
|
}
|
||||||
|
if rules[i].GetProtocol() != rules[j].GetProtocol() {
|
||||||
|
return rules[i].GetProtocol() < rules[j].GetProtocol()
|
||||||
|
}
|
||||||
|
|
||||||
|
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
|
||||||
|
})
|
||||||
|
|
||||||
|
var lastIP string
|
||||||
|
for _, rule := range rules {
|
||||||
|
dPort := portToString(rule.GetDestinationPort())
|
||||||
|
tPort := portToString(rule.GetTranslatedPort())
|
||||||
|
if lastIP != rule.GetTranslatedAddress() {
|
||||||
|
lastIP = rule.GetTranslatedAddress()
|
||||||
|
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFirstPort(portInfo *proto.PortInfo) int {
|
||||||
|
switch v := portInfo.PortSelection.(type) {
|
||||||
|
case *proto.PortInfo_Port:
|
||||||
|
return int(v.Port)
|
||||||
|
case *proto.PortInfo_Range_:
|
||||||
|
return int(v.Range.GetStart())
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func portToString(translatedPort *proto.PortInfo) string {
|
||||||
|
switch v := translatedPort.PortSelection.(type) {
|
||||||
|
case *proto.PortInfo_Port:
|
||||||
|
return fmt.Sprintf("%d", v.Port)
|
||||||
|
case *proto.PortInfo_Range_:
|
||||||
|
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
|
||||||
|
default:
|
||||||
|
return "No port specified"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -145,6 +145,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
rootCmd.AddCommand(sshCmd)
|
rootCmd.AddCommand(sshCmd)
|
||||||
rootCmd.AddCommand(networksCMD)
|
rootCmd.AddCommand(networksCMD)
|
||||||
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
|
||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||||
@@ -153,6 +154,8 @@ func init() {
|
|||||||
networksCMD.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
|
|
||||||
|
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
|
||||||
|
|
||||||
debugCmd.AddCommand(debugBundleCmd)
|
debugCmd.AddCommand(debugBundleCmd)
|
||||||
debugCmd.AddCommand(logCmd)
|
debugCmd.AddCommand(logCmd)
|
||||||
logCmd.AddCommand(logLevelCmd)
|
logCmd.AddCommand(logLevelCmd)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
@@ -89,7 +90,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
// TODO: make after-startup backoff err available
|
// TODO: make after-startup backoff err available
|
||||||
run := make(chan error, 1)
|
run := make(chan struct{}, 1)
|
||||||
|
clientErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := client.Run(run); err != nil {
|
if err := client.Run(run); err != nil {
|
||||||
run <- err
|
clientErr <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||||
}
|
}
|
||||||
return startCtx.Err()
|
return startCtx.Err()
|
||||||
case err := <-run:
|
case err := <-clientErr:
|
||||||
if err != nil {
|
return fmt.Errorf("startup: %w", err)
|
||||||
if stopErr := client.Stop(); stopErr != nil {
|
case <-run:
|
||||||
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("startup: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.connect = client
|
c.connect = client
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ import (
|
|||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() device.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
|||||||
@@ -30,10 +30,8 @@ type entry struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type aclManager struct {
|
type aclManager struct {
|
||||||
iptablesClient *iptables.IPTables
|
iptablesClient *iptables.IPTables
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
|
||||||
|
|
||||||
entries aclEntries
|
entries aclEntries
|
||||||
optionalEntries map[string][]entry
|
optionalEntries map[string][]entry
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
@@ -41,12 +39,10 @@ type aclManager struct {
|
|||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||||
m := &aclManager{
|
m := &aclManager{
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
routingFwChainName: routingFwChainName,
|
|
||||||
|
|
||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
optionalEntries: make(map[string][]entry),
|
optionalEntries: make(map[string][]entry),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
@@ -314,9 +310,12 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||||
|
|
||||||
|
// Inbound is handled by our ACLs, the rest is dropped.
|
||||||
|
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
|
||||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||||
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ type Manager struct {
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return nil, fmt.Errorf("create router: %w", err)
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
@@ -166,7 +166,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -226,6 +226,22 @@ func (m *Manager) DisableRouting() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,15 +10,15 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ifaceMock = &iFaceMock{
|
var ifaceMock = &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMock struct {
|
type iFaceMock struct {
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
func (i *iFaceMock) Name() string {
|
||||||
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
|
|||||||
panic("NameFunc is not set")
|
panic("NameFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
func (i *iFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc != nil {
|
if i.AddressFunc != nil {
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset(nil)
|
err := manager.Close(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -100,14 +100,14 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||||
require.NoError(t, err, "failed check chain exists")
|
require.NoError(t, err, "failed check chain exists")
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -136,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset(nil)
|
err := manager.Close(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -166,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
err = manager.Reset(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -204,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset(nil)
|
err := manager.Close(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
@@ -23,22 +24,36 @@ import (
|
|||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
// constants needed to manage and create iptable rules
|
||||||
const (
|
const (
|
||||||
tableFilter = "filter"
|
tableFilter = "filter"
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
tableMangle = "mangle"
|
tableMangle = "mangle"
|
||||||
|
|
||||||
chainPOSTROUTING = "POSTROUTING"
|
chainPOSTROUTING = "POSTROUTING"
|
||||||
chainPREROUTING = "PREROUTING"
|
chainPREROUTING = "PREROUTING"
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
chainRTFWD = "NETBIRD-RT-FWD"
|
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||||
|
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||||
chainRTPRE = "NETBIRD-RT-PRE"
|
chainRTPRE = "NETBIRD-RT-PRE"
|
||||||
|
chainRTRDR = "NETBIRD-RT-RDR"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
|
||||||
jumpPre = "jump-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNat = "jump-nat"
|
jumpNatPre = "jump-nat-pre"
|
||||||
matchSet = "--match-set"
|
jumpNatPost = "jump-nat-post"
|
||||||
|
matchSet = "--match-set"
|
||||||
|
|
||||||
|
dnatSuffix = "_dnat"
|
||||||
|
snatSuffix = "_snat"
|
||||||
|
fwdSuffix = "_fwd"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ruleInfo struct {
|
||||||
|
chain string
|
||||||
|
table string
|
||||||
|
rule []string
|
||||||
|
}
|
||||||
|
|
||||||
type routeFilteringRuleParams struct {
|
type routeFilteringRuleParams struct {
|
||||||
Sources []netip.Prefix
|
Sources []netip.Prefix
|
||||||
Destination netip.Prefix
|
Destination netip.Prefix
|
||||||
@@ -62,6 +77,7 @@ type router struct {
|
|||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||||
@@ -69,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
|
|||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
@@ -139,9 +156,9 @@ func (r *router) AddRouteFiltering(
|
|||||||
var err error
|
var err error
|
||||||
if action == firewall.ActionDrop {
|
if action == firewall.ActionDrop {
|
||||||
// after the established rule
|
// after the established rule
|
||||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||||
} else {
|
} else {
|
||||||
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -156,12 +173,12 @@ func (r *router) AddRouteFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
ruleKey := rule.GetRuleID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
setName := r.findSetNameInRule(rule)
|
setName := r.findSetNameInRule(rule)
|
||||||
|
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("delete route rule: %v", err)
|
return fmt.Errorf("delete route rule: %v", err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
@@ -212,6 +229,10 @@ func (r *router) deleteIpSet(setName string) error {
|
|||||||
|
|
||||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if r.legacyManagement {
|
if r.legacyManagement {
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
@@ -238,6 +259,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -264,7 +289,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,7 +302,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
@@ -305,7 +330,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
delete(r.rules, k)
|
delete(r.rules, k)
|
||||||
@@ -343,9 +368,11 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
chain string
|
chain string
|
||||||
table string
|
table string
|
||||||
}{
|
}{
|
||||||
{chainRTFWD, tableFilter},
|
{chainRTFWDIN, tableFilter},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTFWDOUT, tableFilter},
|
||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTRDR, tableNat},
|
||||||
} {
|
} {
|
||||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -365,16 +392,22 @@ func (r *router) createContainers() error {
|
|||||||
chain string
|
chain string
|
||||||
table string
|
table string
|
||||||
}{
|
}{
|
||||||
{chainRTFWD, tableFilter},
|
{chainRTFWDIN, tableFilter},
|
||||||
|
{chainRTFWDOUT, tableFilter},
|
||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTRDR, tableNat},
|
||||||
} {
|
} {
|
||||||
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
|
||||||
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
|
||||||
return fmt.Errorf("insert established rule: %w", err)
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -415,27 +448,6 @@ func (r *router) addPostroutingRules() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createAndSetupChain(chain string) error {
|
|
||||||
table := r.getTableForChain(chain)
|
|
||||||
|
|
||||||
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
|
||||||
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) getTableForChain(chain string) string {
|
|
||||||
switch chain {
|
|
||||||
case chainRTNAT:
|
|
||||||
return tableNat
|
|
||||||
case chainRTPRE:
|
|
||||||
return tableMangle
|
|
||||||
default:
|
|
||||||
return tableFilter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) insertEstablishedRule(chain string) error {
|
func (r *router) insertEstablishedRule(chain string) error {
|
||||||
establishedRule := getConntrackEstablished()
|
establishedRule := getConntrackEstablished()
|
||||||
|
|
||||||
@@ -454,28 +466,43 @@ func (r *router) addJumpRules() error {
|
|||||||
// Jump to NAT chain
|
// Jump to NAT chain
|
||||||
natRule := []string{"-j", chainRTNAT}
|
natRule := []string{"-j", chainRTNAT}
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||||
return fmt.Errorf("add nat jump rule: %v", err)
|
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||||
}
|
}
|
||||||
r.rules[jumpNat] = natRule
|
r.rules[jumpNatPost] = natRule
|
||||||
|
|
||||||
// Jump to prerouting chain
|
// Jump to mangle prerouting chain
|
||||||
preRule := []string{"-j", chainRTPRE}
|
preRule := []string{"-j", chainRTPRE}
|
||||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||||
return fmt.Errorf("add prerouting jump rule: %v", err)
|
return fmt.Errorf("add mangle prerouting jump rule: %v", err)
|
||||||
}
|
}
|
||||||
r.rules[jumpPre] = preRule
|
r.rules[jumpManglePre] = preRule
|
||||||
|
|
||||||
|
// Jump to nat prerouting chain
|
||||||
|
rdrRule := []string{"-j", chainRTRDR}
|
||||||
|
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
|
||||||
|
return fmt.Errorf("add nat prerouting jump rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpNatPre] = rdrRule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) cleanJumpRules() error {
|
func (r *router) cleanJumpRules() error {
|
||||||
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
table := tableNat
|
var table, chain string
|
||||||
chain := chainPOSTROUTING
|
switch ruleKey {
|
||||||
if ruleKey == jumpPre {
|
case jumpNatPost:
|
||||||
|
table = tableNat
|
||||||
|
chain = chainPOSTROUTING
|
||||||
|
case jumpManglePre:
|
||||||
table = tableMangle
|
table = tableMangle
|
||||||
chain = chainPREROUTING
|
chain = chainPREROUTING
|
||||||
|
case jumpNatPre:
|
||||||
|
table = tableNat
|
||||||
|
chain = chainPREROUTING
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||||
@@ -520,6 +547,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = rule
|
r.rules[ruleKey] = rule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,6 +564,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
log.Debugf("marking rule %s not found", ruleKey)
|
log.Debugf("marking rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -564,6 +594,137 @@ func (r *router) updateState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
toDestination := rule.TranslatedAddress.String()
|
||||||
|
switch {
|
||||||
|
case len(rule.TranslatedPort.Values) == 0:
|
||||||
|
// no translated port, use original port
|
||||||
|
case len(rule.TranslatedPort.Values) == 1:
|
||||||
|
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||||
|
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||||
|
// need the "/originalport" suffix to avoid dnat port randomization
|
||||||
|
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := strings.ToLower(string(rule.Protocol))
|
||||||
|
|
||||||
|
rules := make(map[string]ruleInfo, 3)
|
||||||
|
|
||||||
|
// DNAT rule
|
||||||
|
dnatRule := []string{
|
||||||
|
"!", "-i", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-j", "DNAT",
|
||||||
|
"--to-destination", toDestination,
|
||||||
|
}
|
||||||
|
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||||
|
rules[ruleKey+dnatSuffix] = ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTRDR,
|
||||||
|
rule: dnatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SNAT rule
|
||||||
|
snatRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-d", rule.TranslatedAddress.String(),
|
||||||
|
"-j", "MASQUERADE",
|
||||||
|
}
|
||||||
|
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||||
|
rules[ruleKey+snatSuffix] = ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTNAT,
|
||||||
|
rule: snatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward filtering rule, if fwd policy is DROP
|
||||||
|
forwardRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-d", rule.TranslatedAddress.String(),
|
||||||
|
"-j", "ACCEPT",
|
||||||
|
}
|
||||||
|
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||||
|
rules[ruleKey+fwdSuffix] = ruleInfo{
|
||||||
|
table: tableFilter,
|
||||||
|
chain: chainRTFWDOUT,
|
||||||
|
rule: forwardRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, ruleInfo := range rules {
|
||||||
|
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||||
|
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||||
|
log.Errorf("rollback failed: %v", rollbackErr)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||||
|
}
|
||||||
|
r.rules[key] = ruleInfo.rule
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for key, ruleInfo := range rules {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||||
|
// On rollback error, add to rules map for next cleanup
|
||||||
|
r.rules[key] = ruleInfo.rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr != nil {
|
||||||
|
r.updateState()
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+fwdSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||||
var rule []string
|
var rule []string
|
||||||
|
|
||||||
|
|||||||
@@ -39,12 +39,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Now 5 rules:
|
// Now 5 rules:
|
||||||
// 1. established rule in forward chain
|
// 1. established rule forward in
|
||||||
// 2. jump rule to NAT chain
|
// 2. estbalished rule forward out
|
||||||
// 3. jump rule to PRE chain
|
// 3. jump rule to POST nat chain
|
||||||
// 4. static outbound masquerade rule
|
// 4. jump rule to PRE mangle chain
|
||||||
// 5. static return masquerade rule
|
// 5. jump rule to PRE nat chain
|
||||||
require.Len(t, manager.rules, 5, "should have created rules map")
|
// 6. static outbound masquerade rule
|
||||||
|
// 7. static return masquerade rule
|
||||||
|
require.Len(t, manager.rules, 7, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
@@ -332,14 +334,14 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
rule, ok := r.rules[ruleKey.ID()]
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
// Log the internal rule
|
// Log the internal rule
|
||||||
t.Logf("Internal rule: %v", rule)
|
t.Logf("Internal rule: %v", rule)
|
||||||
|
|
||||||
// Check if the rule exists in iptables
|
// Check if the rule exists in iptables
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||||
assert.NoError(t, err, "Failed to check rule existence")
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
assert.True(t, exists, "Rule not found in iptables")
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,6 @@ type Rule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) ID() string {
|
||||||
return r.ruleID
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,21 +4,20 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress iface.WGAddress `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Name() string {
|
func (i *InterfaceState) Name() string {
|
||||||
return i.NameStr
|
return i.NameStr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Address() device.WGAddress {
|
func (i *InterfaceState) Address() wgaddr.Address {
|
||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +61,7 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ipt.Reset(nil); err != nil {
|
if err := ipt.Close(nil); err != nil {
|
||||||
return fmt.Errorf("reset iptables manager: %w", err)
|
return fmt.Errorf("reset iptables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ const (
|
|||||||
// Each firewall type for different OS can use different type
|
// Each firewall type for different OS can use different type
|
||||||
// of the properties to hold data of the created rule
|
// of the properties to hold data of the created rule
|
||||||
type Rule interface {
|
type Rule interface {
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
GetRuleID() string
|
ID() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RuleDirection is the traffic direction which a rule is applied
|
// RuleDirection is the traffic direction which a rule is applied
|
||||||
@@ -94,8 +94,8 @@ type Manager interface {
|
|||||||
// SetLegacyManagement sets the legacy management mode
|
// SetLegacyManagement sets the legacy management mode
|
||||||
SetLegacyManagement(legacy bool) error
|
SetLegacyManagement(legacy bool) error
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close closes the firewall manager
|
||||||
Reset(stateManager *statemanager.Manager) error
|
Close(stateManager *statemanager.Manager) error
|
||||||
|
|
||||||
// Flush the changes to firewall controller
|
// Flush the changes to firewall controller
|
||||||
Flush() error
|
Flush() error
|
||||||
@@ -105,6 +105,12 @@ type Manager interface {
|
|||||||
EnableRouting() error
|
EnableRouting() error
|
||||||
|
|
||||||
DisableRouting() error
|
DisableRouting() error
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
AddDNATRule(ForwardRule) (Rule, error)
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
DeleteDNATRule(Rule) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
|
|||||||
27
client/firewall/manager/forward_rule.go
Normal file
27
client/firewall/manager/forward_rule.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ForwardRule todo figure out better place to this to avoid circular imports
|
||||||
|
type ForwardRule struct {
|
||||||
|
Protocol Protocol
|
||||||
|
DestinationPort Port
|
||||||
|
TranslatedAddress netip.Addr
|
||||||
|
TranslatedPort Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ForwardRule) ID() string {
|
||||||
|
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||||
|
r.Protocol,
|
||||||
|
r.DestinationPort.String(),
|
||||||
|
r.TranslatedAddress.String(),
|
||||||
|
r.TranslatedPort.String())
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ForwardRule) String() string {
|
||||||
|
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
|
||||||
|
}
|
||||||
@@ -1,30 +1,12 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Protocol is the protocol of the port
|
|
||||||
type Protocol string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ProtocolTCP is the TCP protocol
|
|
||||||
ProtocolTCP Protocol = "tcp"
|
|
||||||
|
|
||||||
// ProtocolUDP is the UDP protocol
|
|
||||||
ProtocolUDP Protocol = "udp"
|
|
||||||
|
|
||||||
// ProtocolICMP is the ICMP protocol
|
|
||||||
ProtocolICMP Protocol = "icmp"
|
|
||||||
|
|
||||||
// ProtocolALL cover all supported protocols
|
|
||||||
ProtocolALL Protocol = "all"
|
|
||||||
|
|
||||||
// ProtocolUnknown unknown protocol
|
|
||||||
ProtocolUnknown Protocol = "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Port of the address for firewall rule
|
// Port of the address for firewall rule
|
||||||
|
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||||
type Port struct {
|
type Port struct {
|
||||||
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||||
IsRange bool
|
IsRange bool
|
||||||
@@ -33,6 +15,25 @@ type Port struct {
|
|||||||
Values []uint16
|
Values []uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewPort(ports ...int) (*Port, error) {
|
||||||
|
if len(ports) == 0 {
|
||||||
|
return nil, fmt.Errorf("no port provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
ports16 := make([]uint16, len(ports))
|
||||||
|
for i, port := range ports {
|
||||||
|
if port < 1 || port > 65535 {
|
||||||
|
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
|
||||||
|
}
|
||||||
|
ports16[i] = uint16(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Port{
|
||||||
|
IsRange: len(ports) > 1,
|
||||||
|
Values: ports16,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// String interface implementation
|
// String interface implementation
|
||||||
func (p *Port) String() string {
|
func (p *Port) String() string {
|
||||||
var ports string
|
var ports string
|
||||||
|
|||||||
19
client/firewall/manager/protocol.go
Normal file
19
client/firewall/manager/protocol.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
// Protocol is the protocol of the port
|
||||||
|
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||||
|
type Protocol string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProtocolTCP is the TCP protocol
|
||||||
|
ProtocolTCP Protocol = "tcp"
|
||||||
|
|
||||||
|
// ProtocolUDP is the UDP protocol
|
||||||
|
ProtocolUDP Protocol = "udp"
|
||||||
|
|
||||||
|
// ProtocolICMP is the ICMP protocol
|
||||||
|
ProtocolICMP Protocol = "icmp"
|
||||||
|
|
||||||
|
// ProtocolALL cover all supported protocols
|
||||||
|
ProtocolALL Protocol = "all"
|
||||||
|
)
|
||||||
@@ -127,7 +127,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
log.Errorf("failed to delete mangle rule: %v", err)
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.ID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
log.Errorf("failed to delete mangle rule: %v", err)
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.ID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,7 +176,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.ID())
|
||||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||||
|
|
||||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ const (
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
// We only need to record minimal interface state for potential recreation.
|
// We only need to record minimal interface state for potential recreation.
|
||||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||||
// cleanup using Reset() without needing to store specific rules.
|
// cleanup using Close() without needing to store specific rules.
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
@@ -242,7 +242,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -342,6 +342,22 @@ func (m *Manager) Flush() error {
|
|||||||
return m.aclManager.Flush()
|
return m.aclManager.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -16,15 +16,15 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ifaceMock = &iFaceMock{
|
var ifaceMock = &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMock struct {
|
type iFaceMock struct {
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
func (i *iFaceMock) Name() string {
|
||||||
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
|
|||||||
panic("NameFunc is not set")
|
panic("NameFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
func (i *iFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc != nil {
|
if i.AddressFunc != nil {
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = manager.Reset(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
@@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// established rule remains
|
// established rule remains
|
||||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||||
|
|
||||||
err = manager.Reset(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
@@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(nil); err != nil {
|
if err := manager.Close(nil); err != nil {
|
||||||
t.Errorf("clear the manager state: %v", err)
|
t.Errorf("clear the manager state: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
err := manager.Reset(nil)
|
err := manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset manager state")
|
require.NoError(t, err, "failed to reset manager state")
|
||||||
|
|
||||||
// Verify iptables output after reset
|
// Verify iptables output after reset
|
||||||
|
|||||||
@@ -14,23 +14,31 @@ import (
|
|||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/google/nftables/xt"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
tableNat = "nat"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
chainNameForward = "FORWARD"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
|
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||||
|
chainNameForward = "FORWARD"
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
|
|
||||||
|
dnatSuffix = "_dnat"
|
||||||
|
snatSuffix = "_snat"
|
||||||
)
|
)
|
||||||
|
|
||||||
const refreshRulesMapError = "refresh rules map: %w"
|
const refreshRulesMapError = "refresh rules map: %w"
|
||||||
@@ -49,16 +57,18 @@ type router struct {
|
|||||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||||
r := &router{
|
r := &router{
|
||||||
conn: &nftables.Conn{},
|
conn: &nftables.Conn{},
|
||||||
workTable: workTable,
|
workTable: workTable,
|
||||||
chains: make(map[string]*nftables.Chain),
|
chains: make(map[string]*nftables.Chain),
|
||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
@@ -98,7 +108,52 @@ func (r *router) Reset() error {
|
|||||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||||
r.ipsetCounter.Clear()
|
r.ipsetCounter.Clear()
|
||||||
|
|
||||||
return r.removeAcceptForwardRules()
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeAcceptForwardRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeNatPreroutingRules() error {
|
||||||
|
table := &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
}
|
||||||
|
chain := &nftables.Chain{
|
||||||
|
Name: chainNameNatPrerouting,
|
||||||
|
Table: table,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
}
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules from nat table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Delete rules that have our UserData suffix
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
@@ -133,14 +188,22 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingRdr,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
// Chain is created by acl manager
|
// Chain is created by acl manager
|
||||||
// TODO: move creation to a common place
|
// TODO: move creation to a common place
|
||||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||||
Name: chainNamePrerouting,
|
Name: chainNamePrerouting,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
// Add the single NAT rule that matches on mark
|
||||||
@@ -281,7 +344,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleKey := rule.GetRuleID()
|
ruleKey := rule.ID()
|
||||||
nftRule, exists := r.rules[ruleKey]
|
nftRule, exists := r.rules[ruleKey]
|
||||||
if !exists {
|
if !exists {
|
||||||
log.Debugf("route rule %s not found", ruleKey)
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
@@ -410,6 +473,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
|||||||
|
|
||||||
// AddNatRule appends a nftables rule pair to the nat chain
|
// AddNatRule appends a nftables rule pair to the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -836,6 +903,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
|
|
||||||
// RemoveNatRule removes the prerouting mark rule
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -896,6 +967,269 @@ func (r *router) refreshRulesMap() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
protoNum, err := protoToInt(rule.Protocol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.addDnatMasq(rule, protoNum, ruleKey)
|
||||||
|
|
||||||
|
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
||||||
|
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
||||||
|
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
||||||
|
// TODO: find chains with drop policies and add rules there
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
|
||||||
|
dnatExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
|
||||||
|
|
||||||
|
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
||||||
|
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
||||||
|
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
||||||
|
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
||||||
|
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dnatExprs = append(dnatExprs, additionalExprs...)
|
||||||
|
|
||||||
|
dnatExprs = append(dnatExprs,
|
||||||
|
&expr.NAT{
|
||||||
|
Type: expr.NATTypeDestNAT,
|
||||||
|
Family: uint32(nftables.TableFamilyIPv4),
|
||||||
|
RegAddrMin: 1,
|
||||||
|
RegProtoMin: regProtoMin,
|
||||||
|
RegProtoMax: regProtoMax,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingRdr],
|
||||||
|
Exprs: dnatExprs,
|
||||||
|
UserData: []byte(ruleKey + dnatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
switch {
|
||||||
|
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||||
|
return r.handlePortRange(rule)
|
||||||
|
case len(rule.TranslatedPort.Values) == 0:
|
||||||
|
return r.handleAddressOnly(rule)
|
||||||
|
case len(rule.TranslatedPort.Values) == 1:
|
||||||
|
return r.handleSinglePort(rule)
|
||||||
|
default:
|
||||||
|
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 3,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 2, 3, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 0, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 2, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
|
||||||
|
dnatExprs = append(dnatExprs,
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Target{
|
||||||
|
Name: "DNAT",
|
||||||
|
Rev: 2,
|
||||||
|
Info: &xt.NatRange2{
|
||||||
|
NatRange: xt.NatRange{
|
||||||
|
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
||||||
|
MinIP: rule.TranslatedAddress.AsSlice(),
|
||||||
|
MaxIP: rule.TranslatedAddress.AsSlice(),
|
||||||
|
MinPort: rule.TranslatedPort.Values[0],
|
||||||
|
MaxPort: rule.TranslatedPort.Values[1],
|
||||||
|
},
|
||||||
|
BasePort: rule.DestinationPort.Values[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
},
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: chainNameNatPrerouting,
|
||||||
|
Table: r.filterTable,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
},
|
||||||
|
Exprs: dnatExprs,
|
||||||
|
UserData: []byte(ruleKey + dnatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
|
||||||
|
masqExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 16,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
|
||||||
|
masqExprs = append(masqExprs, &expr.Masq{})
|
||||||
|
|
||||||
|
masqRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: masqExprs,
|
||||||
|
UserData: []byte(ruleKey + snatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(masqRule)
|
||||||
|
r.rules[ruleKey+snatSuffix] = masqRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
|
if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if merr == nil {
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||||
var offset uint32
|
var offset uint32
|
||||||
@@ -959,15 +1293,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
if port.IsRange && len(port.Values) == 2 {
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
// Handle port range
|
// Handle port range
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Cmp{
|
&expr.Range{
|
||||||
Op: expr.CmpOpGte,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||||
},
|
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpLte,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
// need fw manager to init both acl mgr and router for all chains to be present
|
// need fw manager to init both acl mgr and router for all chains to be present
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
rule, ok := r.rules[ruleKey.ID()]
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
t.Log("Internal rule expressions:")
|
t.Log("Internal rule expressions:")
|
||||||
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
var nftRule *nftables.Rule
|
var nftRule *nftables.Rule
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if string(rule.UserData) == ruleKey.GetRuleID() {
|
if string(rule.UserData) == ruleKey.ID() {
|
||||||
nftRule = rule
|
nftRule = rule
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
|||||||
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||||
payloadFound = true
|
payloadFound = true
|
||||||
}
|
}
|
||||||
case *expr.Cmp:
|
case *expr.Range:
|
||||||
if port.IsRange {
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
fromPort := binary.BigEndian.Uint16(ex.FromData)
|
||||||
|
toPort := binary.BigEndian.Uint16(ex.ToData)
|
||||||
|
if fromPort == port.Values[0] && toPort == port.Values[1] {
|
||||||
portMatchFound = true
|
portMatchFound = true
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
|
case *expr.Cmp:
|
||||||
|
if !port.IsRange {
|
||||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||||
portValue := binary.BigEndian.Uint16(ex.Data)
|
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||||
for _, p := range port.Values {
|
for _, p := range port.Values {
|
||||||
if uint16(p) == portValue {
|
if p == portValue {
|
||||||
portMatchFound = true
|
portMatchFound = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,6 @@ type Rule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) ID() string {
|
||||||
return r.ruleID
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,21 +3,20 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress iface.WGAddress `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Name() string {
|
func (i *InterfaceState) Name() string {
|
||||||
return i.NameStr
|
return i.NameStr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Address() device.WGAddress {
|
func (i *InterfaceState) Address() wgaddr.Address {
|
||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,7 +38,7 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
return fmt.Errorf("create nftables manager: %w", err)
|
return fmt.Errorf("create nftables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := nft.Reset(nil); err != nil {
|
if err := nft.Close(nil); err != nil {
|
||||||
return fmt.Errorf("reset nftables manager: %w", err)
|
return fmt.Errorf("reset nftables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,12 +8,11 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -22,17 +21,14 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if m.forwarder != nil {
|
||||||
@@ -48,7 +44,7 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Reset(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,8 +20,8 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close closes the firewall manager
|
||||||
func (m *Manager) Reset(*statemanager.Manager) error {
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -31,17 +30,14 @@ func (m *Manager) Reset(*statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if m.forwarder != nil {
|
||||||
|
|||||||
@@ -3,14 +3,14 @@ package common
|
|||||||
import (
|
import (
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
GetWGDevice() *wgdevice.Device
|
GetWGDevice() *wgdevice.Device
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -39,8 +40,8 @@ type ICMPTracker struct {
|
|||||||
connections map[ICMPConnKey]*ICMPConnTrack
|
connections map[ICMPConnKey]*ICMPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
done chan struct{}
|
|
||||||
ipPool *PreallocatedIPs
|
ipPool *PreallocatedIPs
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,16 +51,18 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
|||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
tracker := &ICMPTracker{
|
tracker := &ICMPTracker{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
ipPool: NewPreallocatedIPs(),
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,12 +122,14 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
|
|||||||
conn.Sequence == seq
|
conn.Sequence == seq
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanupRoutine() {
|
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.tickerCancel()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.cleanupTicker.C:
|
case <-t.cleanupTicker.C:
|
||||||
t.cleanup()
|
t.cleanup()
|
||||||
case <-t.done:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -146,8 +151,7 @@ func (t *ICMPTracker) cleanup() {
|
|||||||
|
|
||||||
// Close stops the cleanup routine and releases resources
|
// Close stops the cleanup routine and releases resources
|
||||||
func (t *ICMPTracker) Close() {
|
func (t *ICMPTracker) Close() {
|
||||||
t.cleanupTicker.Stop()
|
t.tickerCancel()
|
||||||
close(t.done)
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
for _, conn := range t.connections {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package conntrack
|
|||||||
// TODO: Send RST packets for invalid/timed-out connections
|
// TODO: Send RST packets for invalid/timed-out connections
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -85,23 +86,26 @@ type TCPTracker struct {
|
|||||||
connections map[ConnKey]*TCPConnTrack
|
connections map[ConnKey]*TCPConnTrack
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
done chan struct{}
|
tickerCancel context.CancelFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
ipPool *PreallocatedIPs
|
ipPool *PreallocatedIPs
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
tracker := &TCPTracker{
|
tracker := &TCPTracker{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*TCPConnTrack),
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
tickerCancel: cancel,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
ipPool: NewPreallocatedIPs(),
|
ipPool: NewPreallocatedIPs(),
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,12 +319,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPTracker) cleanupRoutine() {
|
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.cleanupTicker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.cleanupTicker.C:
|
case <-t.cleanupTicker.C:
|
||||||
t.cleanup()
|
t.cleanup()
|
||||||
case <-t.done:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -355,8 +361,7 @@ func (t *TCPTracker) cleanup() {
|
|||||||
|
|
||||||
// Close stops the cleanup routine and releases resources
|
// Close stops the cleanup routine and releases resources
|
||||||
func (t *TCPTracker) Close() {
|
func (t *TCPTracker) Close() {
|
||||||
t.cleanupTicker.Stop()
|
t.tickerCancel()
|
||||||
close(t.done)
|
|
||||||
|
|
||||||
// Clean up all remaining IPs
|
// Clean up all remaining IPs
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -26,8 +27,8 @@ type UDPTracker struct {
|
|||||||
connections map[ConnKey]*UDPConnTrack
|
connections map[ConnKey]*UDPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
done chan struct{}
|
|
||||||
ipPool *PreallocatedIPs
|
ipPool *PreallocatedIPs
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -37,16 +38,18 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
tracker := &UDPTracker{
|
tracker := &UDPTracker{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*UDPConnTrack),
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
ipPool: NewPreallocatedIPs(),
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,12 +106,14 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// cleanupRoutine periodically removes stale connections
|
// cleanupRoutine periodically removes stale connections
|
||||||
func (t *UDPTracker) cleanupRoutine() {
|
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.cleanupTicker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.cleanupTicker.C:
|
case <-t.cleanupTicker.C:
|
||||||
t.cleanup()
|
t.cleanup()
|
||||||
case <-t.done:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -131,8 +136,7 @@ func (t *UDPTracker) cleanup() {
|
|||||||
|
|
||||||
// Close stops the cleanup routine and releases resources
|
// Close stops the cleanup routine and releases resources
|
||||||
func (t *UDPTracker) Close() {
|
func (t *UDPTracker) Close() {
|
||||||
t.cleanupTicker.Stop()
|
t.tickerCancel()
|
||||||
close(t.done)
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
for _, conn := range t.connections {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -34,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
assert.NotNil(t, tracker.cleanupTicker)
|
assert.NotNil(t, tracker.cleanupTicker)
|
||||||
assert.NotNil(t, tracker.done)
|
assert.NotNil(t, tracker.tickerCancel)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -154,18 +155,21 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
timeout := 50 * time.Millisecond
|
timeout := 50 * time.Millisecond
|
||||||
cleanupInterval := 25 * time.Millisecond
|
cleanupInterval := 25 * time.Millisecond
|
||||||
|
|
||||||
|
ctx, tickerCancel := context.WithCancel(context.Background())
|
||||||
|
defer tickerCancel()
|
||||||
|
|
||||||
// Create tracker with custom cleanup interval
|
// Create tracker with custom cleanup interval
|
||||||
tracker := &UDPTracker{
|
tracker := &UDPTracker{
|
||||||
connections: make(map[ConnKey]*UDPConnTrack),
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
done: make(chan struct{}),
|
tickerCancel: tickerCancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
ipPool: NewPreallocatedIPs(),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine(ctx)
|
||||||
|
|
||||||
// Add some connections
|
// Add some connections
|
||||||
connections := []struct {
|
connections := []struct {
|
||||||
|
|||||||
@@ -6,19 +6,19 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLocalIPManager(t *testing.T) {
|
func TestLocalIPManager(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupAddr iface.WGAddress
|
setupAddr wgaddr.Address
|
||||||
testIP net.IP
|
testIP net.IP
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Localhost range",
|
name: "Localhost range",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
@@ -30,7 +30,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Localhost standard address",
|
name: "Localhost standard address",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
@@ -42,7 +42,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Localhost range edge",
|
name: "Localhost range edge",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
@@ -54,7 +54,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Local IP matches",
|
name: "Local IP matches",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
@@ -66,7 +66,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Local IP doesn't match",
|
name: "Local IP doesn't match",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
@@ -78,7 +78,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "IPv6 address",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("fe80::1"),
|
IP: net.ParseIP("fe80::1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("fe80::"),
|
IP: net.ParseIP("fe80::"),
|
||||||
@@ -95,7 +95,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
manager := newLocalIPManager()
|
manager := newLocalIPManager()
|
||||||
|
|
||||||
mock := &IFaceMock{
|
mock := &IFaceMock{
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return tt.setupAddr
|
return tt.setupAddr
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,8 +24,8 @@ type PeerRule struct {
|
|||||||
udpHook func([]byte) bool
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *PeerRule) GetRuleID() string {
|
func (r *PeerRule) ID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,7 +39,7 @@ type RouteRule struct {
|
|||||||
action firewall.Action
|
action firewall.Action
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *RouteRule) GetRuleID() string {
|
func (r *RouteRule) ID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,6 +42,8 @@ const (
|
|||||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]PeerRule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
@@ -437,7 +439,7 @@ func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
ruleID := rule.GetRuleID()
|
ruleID := rule.ID()
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == ruleID
|
||||||
})
|
})
|
||||||
@@ -478,6 +480,22 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
||||||
return m.processOutgoingHooks(packetData)
|
return m.processOutgoingHooks(packetData)
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -205,7 +205,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -253,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -452,7 +452,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Setup scenario
|
// Setup scenario
|
||||||
@@ -579,7 +579,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
manager.SetNetwork(&net.IPNet{
|
||||||
@@ -670,7 +670,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
manager.SetNetwork(&net.IPNet{
|
||||||
@@ -789,7 +789,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
manager.SetNetwork(&net.IPNet{
|
||||||
@@ -877,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Reset(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
manager.SetNetwork(&net.IPNet{
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ import (
|
|||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeerACLFiltering(t *testing.T) {
|
func TestPeerACLFiltering(t *testing.T) {
|
||||||
@@ -26,8 +26,8 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: localIP,
|
IP: localIP,
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
@@ -39,7 +39,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = wgNet
|
manager.wgNetwork = wgNet
|
||||||
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: localIP,
|
IP: localIP,
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
@@ -310,7 +310,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
require.False(tb, manager.nativeRouter)
|
require.False(tb, manager.nativeRouter)
|
||||||
|
|
||||||
tb.Cleanup(func() {
|
tb.Cleanup(func() {
|
||||||
require.NoError(tb, manager.Reset(nil))
|
require.NoError(tb, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
return manager
|
return manager
|
||||||
|
|||||||
@@ -16,15 +16,15 @@ import (
|
|||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
GetWGDeviceFunc func() *wgdevice.Device
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
@@ -50,9 +50,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
|||||||
return i.SetFilterFunc(iface)
|
return i.SetFilterFunc(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) Address() iface.WGAddress {
|
func (i *IFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc == nil {
|
if i.AddressFunc == nil {
|
||||||
return iface.WGAddress{}
|
return wgaddr.Address{}
|
||||||
}
|
}
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -135,7 +135,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,7 +149,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -254,7 +254,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.Reset(nil)
|
err = m.Close(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
t.Errorf("failed to reset Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -268,8 +268,8 @@ func TestManagerReset(t *testing.T) {
|
|||||||
func TestNotMatchByIP(t *testing.T) {
|
func TestNotMatchByIP(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
@@ -333,7 +333,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = m.Reset(nil); err != nil {
|
if err = m.Close(nil); err != nil {
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
t.Errorf("failed to reset Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -352,7 +352,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
@@ -403,7 +403,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
@@ -484,7 +484,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(nil); err != nil {
|
if err := manager.Close(nil); err != nil {
|
||||||
t.Errorf("clear the manager state: %v", err)
|
t.Errorf("clear the manager state: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -530,7 +530,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Set up packet parameters
|
// Set up packet parameters
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
@@ -14,6 +13,8 @@ import (
|
|||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
type RecvMessage struct {
|
||||||
@@ -52,9 +53,10 @@ type ICEBind struct {
|
|||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *UniversalUDPMuxDefault
|
udpMux *UniversalUDPMuxDefault
|
||||||
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
ib := &ICEBind{
|
ib := &ICEBind{
|
||||||
StdNetBind: b,
|
StdNetBind: b,
|
||||||
@@ -64,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
|||||||
endpoints: make(map[netip.Addr]net.Conn),
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
closed: true,
|
closed: true,
|
||||||
|
address: address,
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := receiverCreator{
|
rc := receiverCreator{
|
||||||
@@ -108,35 +111,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
|||||||
return s.udpMux, nil
|
return s.udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||||
fakeUDPAddr, err := fakeAddress(peerAddress)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// force IPv4
|
|
||||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
|
||||||
}
|
|
||||||
|
|
||||||
b.endpointsMu.Lock()
|
b.endpointsMu.Lock()
|
||||||
b.endpoints[fakeAddr] = conn
|
b.endpoints[fakeIP] = conn
|
||||||
b.endpointsMu.Unlock()
|
b.endpointsMu.Unlock()
|
||||||
|
|
||||||
return fakeUDPAddr, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
|
||||||
if !ok {
|
|
||||||
log.Warnf("failed to convert IP to netip.Addr")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.endpointsMu.Lock()
|
b.endpointsMu.Lock()
|
||||||
defer b.endpointsMu.Unlock()
|
defer b.endpointsMu.Unlock()
|
||||||
delete(b.endpoints, fakeAddr)
|
|
||||||
|
delete(b.endpoints, fakeIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||||
@@ -161,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
|
|
||||||
s.udpMux = NewUniversalUDPMuxDefault(
|
s.udpMux = NewUniversalUDPMuxDefault(
|
||||||
UniversalUDPMuxParams{
|
UniversalUDPMuxParams{
|
||||||
UDPConn: conn,
|
UDPConn: conn,
|
||||||
Net: s.transportNet,
|
Net: s.transportNet,
|
||||||
FilterFn: s.filterFn,
|
FilterFn: s.filterFn,
|
||||||
|
WGAddress: s.address,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
@@ -275,21 +261,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
|
||||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
|
||||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
|
||||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
|
||||||
if len(octets) != 4 {
|
|
||||||
return nil, fmt.Errorf("invalid IP format")
|
|
||||||
}
|
|
||||||
|
|
||||||
newAddr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
|
||||||
Port: peerAddress.Port,
|
|
||||||
}
|
|
||||||
return newAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||||
return msgsPool.Get().(*[]ipv6.Message)
|
return msgsPool.Get().(*[]ipv6.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import (
|
|||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FilterFn is a function that filters out candidates based on the address.
|
// FilterFn is a function that filters out candidates based on the address.
|
||||||
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
|
|||||||
XORMappedAddrCacheTTL time.Duration
|
XORMappedAddrCacheTTL time.Duration
|
||||||
Net transport.Net
|
Net transport.Net
|
||||||
FilterFn FilterFn
|
FilterFn FilterFn
|
||||||
|
WGAddress wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||||
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
mux: m,
|
mux: m,
|
||||||
logger: params.Logger,
|
logger: params.Logger,
|
||||||
filterFn: params.FilterFn,
|
filterFn: params.FilterFn,
|
||||||
|
address: params.WGAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
// embed UDPMux
|
// embed UDPMux
|
||||||
@@ -118,6 +122,7 @@ type udpConn struct {
|
|||||||
filterFn FilterFn
|
filterFn FilterFn
|
||||||
// TODO: reset cache on route changes
|
// TODO: reset cache on route changes
|
||||||
addrCache sync.Map
|
addrCache sync.Map
|
||||||
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||||
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if u.address.Network.Contains(a.AsSlice()) {
|
||||||
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
|
}
|
||||||
|
|
||||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -9,13 +9,14 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGTunDevice interface {
|
type WGTunDevice interface {
|
||||||
Create() (device.WGConfigurer, error)
|
Create() (device.WGConfigurer, error)
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(address WGAddress) error
|
UpdateAddr(address wgaddr.Address) error
|
||||||
WgAddress() WGAddress
|
WgAddress() wgaddr.Address
|
||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
|||||||
@@ -13,11 +13,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||||
type WGTunDevice struct {
|
type WGTunDevice struct {
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -31,7 +32,7 @@ type WGTunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
|
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||||
// todo implement
|
// todo implement
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
|
|||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) WgAddress() WGAddress {
|
func (t *WGTunDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunDevice struct {
|
type TunDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -29,7 +30,7 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) WgAddress() WGAddress {
|
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,11 +14,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunDevice struct {
|
type TunDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
@@ -30,7 +31,7 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) WgAddress() WGAddress {
|
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
|
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
|
||||||
// todo implement
|
// todo implement
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,12 +14,13 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/sharedsock"
|
"github.com/netbirdio/netbird/sharedsock"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunKernelDevice struct {
|
type TunKernelDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
wgPort int
|
wgPort int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
|
|||||||
filterFn bind.FilterFn
|
filterFn bind.FilterFn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &TunKernelDevice{
|
return &TunKernelDevice{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
bindParams := bind.UniversalUDPMuxParams{
|
bindParams := bind.UniversalUDPMuxParams{
|
||||||
UDPConn: rawSock,
|
UDPConn: rawSock,
|
||||||
Net: t.transportNet,
|
Net: t.transportNet,
|
||||||
FilterFn: t.filterFn,
|
FilterFn: t.filterFn,
|
||||||
|
WGAddress: t.address,
|
||||||
}
|
}
|
||||||
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
||||||
go mux.ReadFromConn(t.ctx)
|
go mux.ReadFromConn(t.ctx)
|
||||||
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return t.udpMux, nil
|
return t.udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
|
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
|
|||||||
return closErr
|
return closErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) WgAddress() WGAddress {
|
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunNetstackDevice struct {
|
type TunNetstackDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
|
|||||||
net *netstack.Net
|
net *netstack.Net
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||||
return &TunNetstackDevice{
|
return &TunNetstackDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
|
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunNetstackDevice) WgAddress() WGAddress {
|
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type USPDevice struct {
|
type USPDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -28,7 +29,7 @@ type USPDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||||
log.Infof("using userspace bind mode")
|
log.Infof("using userspace bind mode")
|
||||||
|
|
||||||
return &USPDevice{
|
return &USPDevice{
|
||||||
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *USPDevice) UpdateAddr(address WGAddress) error {
|
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *USPDevice) WgAddress() WGAddress {
|
func (t *USPDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,13 +13,14 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
|
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
|
||||||
|
|
||||||
type TunDevice struct {
|
type TunDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -32,7 +33,7 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (t *TunDevice) WgAddress() WGAddress {
|
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/freebsd"
|
"github.com/netbirdio/netbird/client/iface/freebsd"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgLink struct {
|
type wgLink struct {
|
||||||
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||||
link, err := freebsd.LinkByName(l.name)
|
link, err := freebsd.LinkByName(l.name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("link by name: %w", err)
|
return fmt.Errorf("link by name: %w", err)
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgLink struct {
|
type wgLink struct {
|
||||||
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||||
//delete existing addresses
|
//delete existing addresses
|
||||||
list, err := netlink.AddrList(l, 0)
|
list, err := netlink.AddrList(l, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -7,13 +7,14 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGTunDevice interface {
|
type WGTunDevice interface {
|
||||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(address WGAddress) error
|
UpdateAddr(address wgaddr.Address) error
|
||||||
WgAddress() WGAddress
|
WgAddress() wgaddr.Address
|
||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,8 +29,6 @@ const (
|
|||||||
WgInterfaceDefault = configurer.WgInterfaceDefault
|
WgInterfaceDefault = configurer.WgInterfaceDefault
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGAddress = device.WGAddress
|
|
||||||
|
|
||||||
type wgProxyFactory interface {
|
type wgProxyFactory interface {
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
Free() error
|
Free() error
|
||||||
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the interface address
|
// Address returns the interface address
|
||||||
func (w *WGIface) Address() device.WGAddress {
|
func (w *WGIface) Address() wgaddr.Address {
|
||||||
return w.tun.WgAddress()
|
return w.tun.WgAddress()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
addr, err := device.ParseWGAddress(newAddr)
|
addr, err := wgaddr.ParseWGAddress(newAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,17 +3,18 @@ package iface
|
|||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
|
|||||||
@@ -6,17 +6,18 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
|
|
||||||
var tun WGTunDevice
|
var tun WGTunDevice
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
|
|||||||
@@ -5,17 +5,18 @@ package iface
|
|||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{}
|
wgIFace := &WGIface{}
|
||||||
|
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
wgIFace.userspaceBind = true
|
wgIFace.userspaceBind = true
|
||||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||||
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
if device.ModuleTunIsLoaded() {
|
if device.ModuleTunIsLoaded() {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
wgIFace.userspaceBind = true
|
wgIFace.userspaceBind = true
|
||||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
|
|
||||||
var tun WGTunDevice
|
var tun WGTunDevice
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
|||||||
|
|
||||||
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err)
|
log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
|
||||||
}
|
}
|
||||||
if skipProxy {
|
if skipProxy {
|
||||||
return nsTunDev, tunNet, nil
|
return nsTunDev, tunNet, nil
|
||||||
|
|||||||
@@ -1,29 +1,29 @@
|
|||||||
package device
|
package wgaddr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGAddress WireGuard parsed address
|
// Address WireGuard parsed address
|
||||||
type WGAddress struct {
|
type Address struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Network *net.IPNet
|
Network *net.IPNet
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||||
func ParseWGAddress(address string) (WGAddress, error) {
|
func ParseWGAddress(address string) (Address, error) {
|
||||||
ip, network, err := net.ParseCIDR(address)
|
ip, network, err := net.ParseCIDR(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGAddress{}, err
|
return Address{}, err
|
||||||
}
|
}
|
||||||
return WGAddress{
|
return Address{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (addr WGAddress) String() string {
|
func (addr Address) String() string {
|
||||||
maskSize, _ := addr.Network.Mask.Size()
|
maskSize, _ := addr.Network.Mask.Size()
|
||||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||||
}
|
}
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -16,13 +17,13 @@ import (
|
|||||||
type ProxyBind struct {
|
type ProxyBind struct {
|
||||||
Bind *bind.ICEBind
|
Bind *bind.ICEBind
|
||||||
|
|
||||||
wgAddr *net.UDPAddr
|
fakeNetIP *netip.AddrPort
|
||||||
wgEndpoint *bind.Endpoint
|
wgBindEndpoint *bind.Endpoint
|
||||||
remoteConn net.Conn
|
remoteConn net.Conn
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
closeMu sync.Mutex
|
closeMu sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
pausedMu sync.Mutex
|
pausedMu sync.Mutex
|
||||||
paused bool
|
paused bool
|
||||||
@@ -33,20 +34,24 @@ type ProxyBind struct {
|
|||||||
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
|
||||||
// WireGuard configuration.
|
// WireGuard configuration.
|
||||||
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
|
||||||
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
|
fakeNetIP, err := fakeAddress(nbAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.wgAddr = addr
|
p.fakeNetIP = fakeNetIP
|
||||||
p.wgEndpoint = addrToEndpoint(addr)
|
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
|
||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
p.ctx, p.cancel = context.WithCancel(ctx)
|
p.ctx, p.cancel = context.WithCancel(ctx)
|
||||||
return err
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
|
||||||
return p.wgAddr
|
return &net.UDPAddr{
|
||||||
|
IP: p.fakeNetIP.Addr().AsSlice(),
|
||||||
|
Port: int(p.fakeNetIP.Port()),
|
||||||
|
Zone: p.fakeNetIP.Addr().Zone(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProxyBind) Work() {
|
func (p *ProxyBind) Work() {
|
||||||
@@ -54,6 +59,8 @@ func (p *ProxyBind) Work() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
|
||||||
|
|
||||||
p.pausedMu.Lock()
|
p.pausedMu.Lock()
|
||||||
p.paused = false
|
p.paused = false
|
||||||
p.pausedMu.Unlock()
|
p.pausedMu.Unlock()
|
||||||
@@ -93,7 +100,7 @@ func (p *ProxyBind) close() error {
|
|||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
p.Bind.RemoveEndpoint(p.wgAddr)
|
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
|
||||||
|
|
||||||
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
|
||||||
return rErr
|
return rErr
|
||||||
@@ -126,7 +133,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
msg := bind.RecvMessage{
|
msg := bind.RecvMessage{
|
||||||
Endpoint: p.wgEndpoint,
|
Endpoint: p.wgBindEndpoint,
|
||||||
Buffer: buf[:n],
|
Buffer: buf[:n],
|
||||||
}
|
}
|
||||||
p.Bind.RecvChan <- msg
|
p.Bind.RecvChan <- msg
|
||||||
@@ -134,8 +141,19 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
|
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
||||||
ip, _ := netip.AddrFromSlice(addr.IP.To4())
|
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
||||||
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
|
func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
|
||||||
return &bind.Endpoint{AddrPort: addrPort}
|
octets := strings.Split(peerAddress.IP.String(), ".")
|
||||||
|
if len(octets) != 4 {
|
||||||
|
return nil, fmt.Errorf("invalid IP format")
|
||||||
|
}
|
||||||
|
|
||||||
|
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse new IP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
|
||||||
|
return &netipAddr, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,8 @@
|
|||||||
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
|
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
|
||||||
!define INSTALLER_NAME "netbird-installer.exe"
|
!define INSTALLER_NAME "netbird-installer.exe"
|
||||||
!define MAIN_APP_EXE "Netbird"
|
!define MAIN_APP_EXE "Netbird"
|
||||||
!define ICON "ui\\netbird.ico"
|
!define ICON "ui\\assets\\netbird.ico"
|
||||||
!define BANNER "ui\\banner.bmp"
|
!define BANNER "ui\\build\\banner.bmp"
|
||||||
!define LICENSE_DATA "..\\LICENSE"
|
!define LICENSE_DATA "..\\LICENSE"
|
||||||
|
|
||||||
!define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}"
|
!define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}"
|
||||||
@@ -22,6 +22,8 @@
|
|||||||
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
|
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
|
||||||
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
|
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
|
||||||
|
|
||||||
|
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||||
|
|
||||||
Unicode True
|
Unicode True
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
@@ -68,6 +70,9 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_PAGE_DIRECTORY
|
!insertmacro MUI_PAGE_DIRECTORY
|
||||||
|
|
||||||
|
; Custom page for autostart checkbox
|
||||||
|
Page custom AutostartPage AutostartPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_INSTFILES
|
!insertmacro MUI_PAGE_INSTFILES
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_FINISH
|
!insertmacro MUI_PAGE_FINISH
|
||||||
@@ -80,8 +85,36 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_LANGUAGE "English"
|
!insertmacro MUI_LANGUAGE "English"
|
||||||
|
|
||||||
|
; Variables for autostart option
|
||||||
|
Var AutostartCheckbox
|
||||||
|
Var AutostartEnabled
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
|
; Function to create the autostart options page
|
||||||
|
Function AutostartPage
|
||||||
|
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
|
||||||
|
|
||||||
|
nsDialogs::Create 1018
|
||||||
|
Pop $0
|
||||||
|
|
||||||
|
${If} $0 == error
|
||||||
|
Abort
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||||
|
Pop $AutostartCheckbox
|
||||||
|
${NSD_Check} $AutostartCheckbox ; Default to checked
|
||||||
|
StrCpy $AutostartEnabled "1" ; Default to enabled
|
||||||
|
|
||||||
|
nsDialogs::Show
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to handle leaving the autostart page
|
||||||
|
Function AutostartPageLeave
|
||||||
|
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
Function GetAppFromCommand
|
Function GetAppFromCommand
|
||||||
Exch $1
|
Exch $1
|
||||||
Push $2
|
Push $2
|
||||||
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
|||||||
|
|
||||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||||
|
|
||||||
|
; Create autostart registry entry based on checkbox
|
||||||
|
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||||
|
${If} $AutostartEnabled == "1"
|
||||||
|
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||||
|
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||||
|
${Else}
|
||||||
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
DetailPrint "Autostart not enabled by user"
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
EnVar::AddValueEx "path" "$INSTDIR"
|
EnVar::AddValueEx "path" "$INSTDIR"
|
||||||
|
|
||||||
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
|||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||||
|
|
||||||
# kill ui client
|
# kill ui client
|
||||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
|
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||||
|
|
||||||
|
; Remove autostart registry entry
|
||||||
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
# wait the service uninstall take unblock the executable
|
# wait the service uninstall take unblock the executable
|
||||||
Sleep 3000
|
Sleep 3000
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type RuleID string
|
type RuleID string
|
||||||
|
|
||||||
func (r RuleID) GetRuleID() string {
|
func (r RuleID) ID() string {
|
||||||
return string(r)
|
return string(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -245,7 +245,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
|||||||
return "", fmt.Errorf("add route rule: %w", err)
|
return "", fmt.Errorf("add route rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return id.RuleID(addedRule.GetRuleID()), nil
|
return id.RuleID(addedRule.ID()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
@@ -515,7 +515,7 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
|||||||
for _, rules := range newRulePairs {
|
for _, rules := range newRulePairs {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
@@ -45,7 +45,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
@@ -58,7 +58,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func(fw manager.Manager) {
|
defer func(fw manager.Manager) {
|
||||||
_ = fw.Reset(nil)
|
_ = fw.Close(nil)
|
||||||
}(fw)
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
@@ -74,7 +74,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
t.Run("add extra rules", func(t *testing.T) {
|
t.Run("add extra rules", func(t *testing.T) {
|
||||||
existedPairs := map[string]struct{}{}
|
existedPairs := map[string]struct{}{}
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
existedPairs[id.GetRuleID()] = struct{}{}
|
existedPairs[id.ID()] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove first rule
|
// remove first rule
|
||||||
@@ -100,7 +100,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
if _, ok := existedPairs[id.GetRuleID()]; ok {
|
if _, ok := existedPairs[id.ID()]; ok {
|
||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -339,7 +339,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
@@ -352,7 +352,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer func(fw manager.Manager) {
|
defer func(fw manager.Manager) {
|
||||||
_ = fw.Reset(nil)
|
_ = fw.Close(nil)
|
||||||
}(fw)
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
iface "github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockIFaceMapper is a mock of IFaceMapper interface.
|
// MockIFaceMapper is a mock of IFaceMapper interface.
|
||||||
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Address mocks base method.
|
// Address mocks base method.
|
||||||
func (m *MockIFaceMapper) Address() iface.WGAddress {
|
func (m *MockIFaceMapper) Address() wgaddr.Address {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Address")
|
ret := m.ctrl.Call(m, "Address")
|
||||||
ret0, _ := ret[0].(iface.WGAddress)
|
ret0, _ := ret[0].(wgaddr.Address)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func NewConnectClient(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run with main logic.
|
// Run with main logic.
|
||||||
func (c *ConnectClient) Run(runningChan chan error) error {
|
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||||
return c.run(MobileDependency{}, runningChan)
|
return c.run(MobileDependency{}, runningChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
return c.run(mobileDependency, nil)
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
|
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
rec := c.statusRecorder
|
rec := c.statusRecorder
|
||||||
@@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer c.statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
runningChanOpen := true
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
if c.isContextCancelled() {
|
if c.ctx.Err() != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
if runningChan != nil && runningChanOpen {
|
if runningChan != nil {
|
||||||
runningChan <- nil
|
select {
|
||||||
close(runningChan)
|
case runningChan <- struct{}{}:
|
||||||
runningChanOpen = false
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
@@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) isContextCancelled() bool {
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetworkMapPersistence enables or disables network map persistence.
|
// SetNetworkMapPersistence enables or disables network map persistence.
|
||||||
// When enabled, the last received network map will be stored and can be retrieved
|
// When enabled, the last received network map will be stored and can be retrieved
|
||||||
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -37,9 +38,9 @@ func (w *mocWGIface) Name() string {
|
|||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Address() iface.WGAddress {
|
func (w *mocWGIface) Address() wgaddr.Address {
|
||||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}
|
}
|
||||||
@@ -1015,7 +1016,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
|
|||||||
mh.AssertExpectations(t)
|
mh.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset mocks
|
// Close mocks
|
||||||
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
|
||||||
mh.ExpectedCalls = nil
|
mh.ExpectedCalls = nil
|
||||||
mh.Calls = nil
|
mh.Calls = nil
|
||||||
|
|||||||
@@ -5,15 +5,15 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGIface defines subset methods of interface required for manager
|
// WGIface defines subset methods of interface required for manager
|
||||||
type WGIface interface {
|
type WGIface interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
ToInterface() *net.Interface
|
ToInterface() *net.Interface
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGIface defines subset methods of interface required for manager
|
// WGIface defines subset methods of interface required for manager
|
||||||
type WGIface interface {
|
type WGIface interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -33,6 +33,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
@@ -169,10 +170,11 @@ type Engine struct {
|
|||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
firewall manager.Manager
|
firewall firewallManager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
acl acl.Manager
|
acl acl.Manager
|
||||||
dnsForwardMgr *dnsfwd.Manager
|
dnsForwardMgr *dnsfwd.Manager
|
||||||
|
ingressGatewayMgr *ingressgw.Manager
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
|
||||||
@@ -266,6 +268,13 @@ func (e *Engine) Stop() error {
|
|||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||||
e.stopDNSServer()
|
e.stopDNSServer()
|
||||||
|
|
||||||
|
if e.ingressGatewayMgr != nil {
|
||||||
|
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||||
|
log.Warnf("failed to cleanup forward rules: %v", err)
|
||||||
|
}
|
||||||
|
e.ingressGatewayMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
if e.routeManager != nil {
|
if e.routeManager != nil {
|
||||||
e.routeManager.Stop(e.stateManager)
|
e.routeManager.Stop(e.stateManager)
|
||||||
}
|
}
|
||||||
@@ -469,15 +478,15 @@ func (e *Engine) initFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rosenpassPort := e.rpManager.GetAddress().Port
|
rosenpassPort := e.rpManager.GetAddress().Port
|
||||||
port := manager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||||
|
|
||||||
// this rule is static and will be torn down on engine down by the firewall manager
|
// this rule is static and will be torn down on engine down by the firewall manager
|
||||||
if _, err := e.firewall.AddPeerFiltering(
|
if _, err := e.firewall.AddPeerFiltering(
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
manager.ProtocolUDP,
|
firewallManager.ProtocolUDP,
|
||||||
nil,
|
nil,
|
||||||
&port,
|
&port,
|
||||||
manager.ActionAccept,
|
firewallManager.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
); err != nil {
|
); err != nil {
|
||||||
@@ -505,10 +514,10 @@ func (e *Engine) blockLanAccess() {
|
|||||||
if _, err := e.firewall.AddRouteFiltering(
|
if _, err := e.firewall.AddRouteFiltering(
|
||||||
[]netip.Prefix{v4},
|
[]netip.Prefix{v4},
|
||||||
network,
|
network,
|
||||||
manager.ProtocolALL,
|
firewallManager.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
manager.ActionDrop,
|
firewallManager.ActionDrop,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
|
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
|
||||||
}
|
}
|
||||||
@@ -912,6 +921,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ingress forward rules
|
||||||
|
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
||||||
|
log.Errorf("failed to update forward rules, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
@@ -1362,7 +1376,7 @@ func (e *Engine) close() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
err := e.firewall.Reset(e.stateManager)
|
err := e.firewall.Close(e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to reset firewall: %s", err)
|
log.Warnf("failed to reset firewall: %s", err)
|
||||||
}
|
}
|
||||||
@@ -1482,7 +1496,7 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetFirewallManager returns the firewall manager
|
// GetFirewallManager returns the firewall manager
|
||||||
func (e *Engine) GetFirewallManager() manager.Manager {
|
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||||
return e.firewall
|
return e.firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1575,16 +1589,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
|||||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// restartEngine restarts the engine by cancelling the client context
|
||||||
func (e *Engine) restartEngine() {
|
func (e *Engine) restartEngine() {
|
||||||
log.Info("restarting engine")
|
e.syncMsgMux.Lock()
|
||||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
if err := e.Stop(); err != nil {
|
if e.ctx.Err() != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info("restarting engine")
|
||||||
|
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
log.Infof("cancelling client, engine will be recreated")
|
log.Infof("cancelling client context, engine will be recreated")
|
||||||
e.clientCancel()
|
e.clientCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1596,34 +1613,17 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
|
|
||||||
e.networkMonitor = networkmonitor.New()
|
e.networkMonitor = networkmonitor.New()
|
||||||
go func() {
|
go func() {
|
||||||
var mu sync.Mutex
|
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||||
var debounceTimer *time.Timer
|
if errors.Is(err, context.Canceled) {
|
||||||
|
log.Infof("network monitor stopped")
|
||||||
// Start the network monitor with a callback, Start will block until the monitor is stopped,
|
return
|
||||||
// a network change is detected, or an error occurs on start up
|
|
||||||
err := e.networkMonitor.Start(e.ctx, func() {
|
|
||||||
// This function is called when a network change is detected
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
if debounceTimer != nil {
|
|
||||||
log.Infof("Network monitor: detected network change, reset debounceTimer")
|
|
||||||
debounceTimer.Stop()
|
|
||||||
}
|
}
|
||||||
|
log.Errorf("network monitor error: %v", err)
|
||||||
// Set a new timer to debounce rapid network changes
|
return
|
||||||
debounceTimer = time.AfterFunc(2*time.Second, func() {
|
|
||||||
// This function is called after the debounce period
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
log.Infof("Network monitor: detected network change, restarting engine")
|
|
||||||
e.restartEngine()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
|
||||||
log.Errorf("Network monitor: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Infof("Network monitor: detected network change, restarting engine")
|
||||||
|
e.restartEngine()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1770,6 +1770,74 @@ func (e *Engine) Address() (netip.Addr, error) {
|
|||||||
return ip.Unmap(), nil
|
return ip.Unmap(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
||||||
|
if e.firewall == nil {
|
||||||
|
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rules) == 0 {
|
||||||
|
if e.ingressGatewayMgr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := e.ingressGatewayMgr.Close()
|
||||||
|
e.ingressGatewayMgr = nil
|
||||||
|
e.statusRecorder.SetIngressGwMgr(nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.ingressGatewayMgr == nil {
|
||||||
|
mgr := ingressgw.NewManager(e.firewall)
|
||||||
|
e.ingressGatewayMgr = mgr
|
||||||
|
e.statusRecorder.SetIngressGwMgr(mgr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
|
||||||
|
for _, rule := range rules {
|
||||||
|
proto, err := convertToFirewallProtocol(rule.GetProtocol())
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dstPortInfo, err := convertPortInfo(rule.GetDestinationPort())
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("invalid destination port '%v': %w", rule.GetDestinationPort(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
translateIP, err := convertToIP(rule.GetTranslatedAddress())
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("failed to convert translated address '%s': %w", rule.GetTranslatedAddress(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
translatePort, err := convertPortInfo(rule.GetTranslatedPort())
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("invalid translate port '%v': %w", rule.GetTranslatedPort(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
forwardRule := firewallManager.ForwardRule{
|
||||||
|
Protocol: proto,
|
||||||
|
DestinationPort: *dstPortInfo,
|
||||||
|
TranslatedAddress: translateIP,
|
||||||
|
TranslatedPort: *translatePort,
|
||||||
|
}
|
||||||
|
|
||||||
|
forwardingRules = append(forwardingRules, forwardRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("updating forwarding rules: %d", len(forwardingRules))
|
||||||
|
if err := e.ingressGatewayMgr.Update(forwardingRules); err != nil {
|
||||||
|
log.Errorf("failed to update forwarding rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
for _, check := range checks {
|
for _, check := range checks {
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -44,6 +45,7 @@ import (
|
|||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
@@ -74,7 +76,7 @@ type MockWGIface struct {
|
|||||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||||
IsUserspaceBindFunc func() bool
|
IsUserspaceBindFunc func() bool
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() device.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
ToInterfaceFunc func() *net.Interface
|
ToInterfaceFunc func() *net.Interface
|
||||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddrFunc func(newAddr string) error
|
UpdateAddrFunc func(newAddr string) error
|
||||||
@@ -113,7 +115,7 @@ func (m *MockWGIface) Name() string {
|
|||||||
return m.NameFunc()
|
return m.NameFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) Address() device.WGAddress {
|
func (m *MockWGIface) Address() wgaddr.Address {
|
||||||
return m.AddressFunc()
|
return m.AddressFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
RemovePeerFunc: func(peerKey string) error {
|
RemovePeerFunc: func(peerKey string) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -1433,7 +1435,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,7 +21,7 @@ type wgIfaceBase interface {
|
|||||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
Name() string
|
Name() string
|
||||||
Address() device.WGAddress
|
Address() wgaddr.Address
|
||||||
ToInterface() *net.Interface
|
ToInterface() *net.Interface
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
|
|||||||
107
client/internal/ingressgw/manager.go
Normal file
107
client/internal/ingressgw/manager.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package ingressgw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DNATFirewall interface {
|
||||||
|
AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error)
|
||||||
|
DeleteDNATRule(rule firewall.Rule) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type RulePair struct {
|
||||||
|
firewall.ForwardRule
|
||||||
|
firewall.Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
dnatFirewall DNATFirewall
|
||||||
|
|
||||||
|
rules map[string]RulePair // keys is the ID of the ForwardRule
|
||||||
|
rulesMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(dnatFirewall DNATFirewall) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
dnatFirewall: dnatFirewall,
|
||||||
|
rules: make(map[string]RulePair),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
|
||||||
|
h.rulesMu.Lock()
|
||||||
|
defer h.rulesMu.Unlock()
|
||||||
|
|
||||||
|
var mErr *multierror.Error
|
||||||
|
|
||||||
|
toDelete := make(map[string]RulePair, len(h.rules))
|
||||||
|
for id, r := range h.rules {
|
||||||
|
toDelete[id] = r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process new/updated rules
|
||||||
|
for _, fwdRule := range forwardRules {
|
||||||
|
id := fwdRule.ID()
|
||||||
|
if _, ok := h.rules[id]; ok {
|
||||||
|
delete(toDelete, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := h.dnatFirewall.AddDNATRule(fwdRule)
|
||||||
|
if err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Infof("forward rule has been added '%s'", fwdRule)
|
||||||
|
h.rules[id] = RulePair{
|
||||||
|
ForwardRule: fwdRule,
|
||||||
|
Rule: rule,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove deleted rules
|
||||||
|
for id, rulePair := range toDelete {
|
||||||
|
if err := h.dnatFirewall.DeleteDNATRule(rulePair.Rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rulePair.ForwardRule.String(), err))
|
||||||
|
}
|
||||||
|
log.Infof("forward rule has been deleted '%s'", rulePair.ForwardRule)
|
||||||
|
delete(h.rules, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) Close() error {
|
||||||
|
h.rulesMu.Lock()
|
||||||
|
defer h.rulesMu.Unlock()
|
||||||
|
|
||||||
|
log.Infof("clean up all (%d) forward rules", len(h.rules))
|
||||||
|
var mErr *multierror.Error
|
||||||
|
for _, rule := range h.rules {
|
||||||
|
if err := h.dnatFirewall.DeleteDNATRule(rule.Rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rule, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.rules = make(map[string]RulePair)
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) Rules() []firewall.ForwardRule {
|
||||||
|
h.rulesMu.Lock()
|
||||||
|
defer h.rulesMu.Unlock()
|
||||||
|
|
||||||
|
rules := make([]firewall.ForwardRule, 0, len(h.rules))
|
||||||
|
for _, rulePair := range h.rules {
|
||||||
|
rules = append(rules, rulePair.ForwardRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
281
client/internal/ingressgw/manager_test.go
Normal file
281
client/internal/ingressgw/manager_test.go
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
package ingressgw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ firewall.Rule = (*MocFwRule)(nil)
|
||||||
|
_ DNATFirewall = &MockDNATFirewall{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type MocFwRule struct {
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocFwRule) ID() string {
|
||||||
|
return string(m.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockDNATFirewall struct {
|
||||||
|
throwError bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDNATFirewall) AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.throwError {
|
||||||
|
return nil, fmt.Errorf("moc error")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwRule := &MocFwRule{
|
||||||
|
id: fwdRule.ID(),
|
||||||
|
}
|
||||||
|
return fwRule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDNATFirewall) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.throwError {
|
||||||
|
return fmt.Errorf("moc error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDNATFirewall) forceToThrowErrors() {
|
||||||
|
m.throwError = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_AddRule(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
|
||||||
|
updates := []firewall.ForwardRule{
|
||||||
|
{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Protocol: firewall.ProtocolUDP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}}
|
||||||
|
|
||||||
|
if err := mgr.Update(updates); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != len(updates) {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_UpdateRule(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleUDP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolUDP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleUDP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 1 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules[0].TranslatedAddress.String() != ruleUDP.TranslatedAddress.String() {
|
||||||
|
t.Errorf("unexpected rule: %v", rules[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules[0].TranslatedPort.String() != ruleUDP.TranslatedPort.String() {
|
||||||
|
t.Errorf("unexpected rule: %v", rules[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules[0].DestinationPort.String() != ruleUDP.DestinationPort.String() {
|
||||||
|
t.Errorf("unexpected rule: %v", rules[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules[0].Protocol != ruleUDP.Protocol {
|
||||||
|
t.Errorf("unexpected rule: %v", rules[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_ExtendRules(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleUDP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolUDP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 2 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_UnderlingError(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleUDP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolUDP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fw.forceToThrowErrors()
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err == nil {
|
||||||
|
t.Errorf("expected error")
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 1 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Cleanup(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 0 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_DeleteBrokenRule(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
|
||||||
|
// force to throw errors when Add DNAT Rule
|
||||||
|
fw.forceToThrowErrors()
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err == nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 0 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
// simulate that to remove a broken rule
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Close(); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Close(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Close(); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 0 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
58
client/internal/message_convert.go
Normal file
58
client/internal/message_convert.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
|
||||||
|
switch protocol {
|
||||||
|
case mgmProto.RuleProtocol_TCP:
|
||||||
|
return firewallManager.ProtocolTCP, nil
|
||||||
|
case mgmProto.RuleProtocol_UDP:
|
||||||
|
return firewallManager.ProtocolUDP, nil
|
||||||
|
case mgmProto.RuleProtocol_ICMP:
|
||||||
|
return firewallManager.ProtocolICMP, nil
|
||||||
|
case mgmProto.RuleProtocol_ALL:
|
||||||
|
return firewallManager.ProtocolALL, nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
|
||||||
|
if portInfo == nil {
|
||||||
|
return nil, errors.New("portInfo cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInfo.GetPort() != 0 {
|
||||||
|
return firewallManager.NewPort(int(portInfo.GetPort()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInfo.GetRange() != nil {
|
||||||
|
return firewallManager.NewPort(int(portInfo.GetRange().Start), int(portInfo.GetRange().End))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("invalid portInfo: %v", portInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToIP(rawIP []byte) (netip.Addr, error) {
|
||||||
|
if rawIP == nil {
|
||||||
|
return netip.Addr{}, errors.New("input bytes cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rawIP) != net.IPv4len && len(rawIP) != net.IPv6len {
|
||||||
|
return netip.Addr{}, fmt.Errorf("invalid IP length: %d", len(rawIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rawIP) == net.IPv4len {
|
||||||
|
return netip.AddrFrom4([4]byte(rawIP)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.AddrFrom16([16]byte(rawIP)), nil
|
||||||
|
}
|
||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to open routing socket: %v", err)
|
return fmt.Errorf("failed to open routing socket: %v", err)
|
||||||
@@ -28,18 +28,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
err := unix.Close(fd)
|
|
||||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
|
||||||
log.Debugf("Network monitor: closed routing socket: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ErrStopped
|
return ctx.Err()
|
||||||
default:
|
default:
|
||||||
buf := make([]byte, 2048)
|
buf := make([]byte, 2048)
|
||||||
n, err := unix.Read(fd, buf)
|
n, err := unix.Read(fd, buf)
|
||||||
@@ -76,11 +68,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case unix.RTM_ADD:
|
case unix.RTM_ADD:
|
||||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||||
go callback()
|
return nil
|
||||||
case unix.RTM_DELETE:
|
case unix.RTM_DELETE:
|
||||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||||
go callback()
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
|
if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
|
||||||
return errors.New("no interfaces available")
|
return errors.New("no interfaces available")
|
||||||
}
|
}
|
||||||
@@ -31,8 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ErrStopped
|
return ctx.Err()
|
||||||
|
|
||||||
// handle route changes
|
// handle route changes
|
||||||
case route := <-routeChan:
|
case route := <-routeChan:
|
||||||
// default route and main table
|
// default route and main table
|
||||||
@@ -43,12 +42,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
// triggered on added/replaced routes
|
// triggered on added/replaced routes
|
||||||
case syscall.RTM_NEWROUTE:
|
case syscall.RTM_NEWROUTE:
|
||||||
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
|
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||||
go callback()
|
|
||||||
return nil
|
return nil
|
||||||
case syscall.RTM_DELROUTE:
|
case syscall.RTM_DELROUTE:
|
||||||
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
|
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
|
||||||
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
|
||||||
go callback()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
routeMonitor, err := systemops.NewRouteMonitor(ctx)
|
routeMonitor, err := systemops.NewRouteMonitor(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create route monitor: %w", err)
|
return fmt.Errorf("failed to create route monitor: %w", err)
|
||||||
@@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ErrStopped
|
return ctx.Err()
|
||||||
case route := <-routeMonitor.RouteUpdates():
|
case route := <-routeMonitor.RouteUpdates():
|
||||||
if route.Destination.Bits() != 0 {
|
if route.Destination.Bits() != 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if routeChanged(route, nexthopv4, nexthopv6, callback) {
|
if routeChanged(route, nexthopv4, nexthopv6) {
|
||||||
break
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
|
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
|
||||||
intf := "<nil>"
|
intf := "<nil>"
|
||||||
if route.Interface != nil {
|
if route.Interface != nil {
|
||||||
intf = route.Interface.Name
|
intf = route.Interface.Name
|
||||||
@@ -51,18 +51,15 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne
|
|||||||
case systemops.RouteModified:
|
case systemops.RouteModified:
|
||||||
// TODO: get routing table to figure out if our route is affected for modified routes
|
// TODO: get routing table to figure out if our route is affected for modified routes
|
||||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
|
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
|
||||||
go callback()
|
|
||||||
return true
|
return true
|
||||||
case systemops.RouteAdded:
|
case systemops.RouteAdded:
|
||||||
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
|
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
|
||||||
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
|
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
|
||||||
go callback()
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
case systemops.RouteDeleted:
|
case systemops.RouteDeleted:
|
||||||
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
|
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
|
||||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
|
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
|
||||||
go callback()
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,12 +1,27 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
package networkmonitor
|
package networkmonitor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"runtime/debug"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrStopped = errors.New("monitor has been stopped")
|
const (
|
||||||
|
debounceTime = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var checkChangeFn = checkChange
|
||||||
|
|
||||||
// NetworkMonitor watches for changes in network configuration.
|
// NetworkMonitor watches for changes in network configuration.
|
||||||
type NetworkMonitor struct {
|
type NetworkMonitor struct {
|
||||||
@@ -19,3 +34,99 @@ type NetworkMonitor struct {
|
|||||||
func New() *NetworkMonitor {
|
func New() *NetworkMonitor {
|
||||||
return &NetworkMonitor{}
|
return &NetworkMonitor{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||||
|
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||||
|
nw.mu.Lock()
|
||||||
|
if nw.cancel != nil {
|
||||||
|
nw.mu.Unlock()
|
||||||
|
return errors.New("network monitor already started")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, nw.cancel = context.WithCancel(ctx)
|
||||||
|
defer nw.cancel()
|
||||||
|
nw.wg.Add(1)
|
||||||
|
nw.mu.Unlock()
|
||||||
|
|
||||||
|
defer nw.wg.Done()
|
||||||
|
|
||||||
|
var nexthop4, nexthop6 systemops.Nexthop
|
||||||
|
|
||||||
|
operation := func() error {
|
||||||
|
var errv4, errv6 error
|
||||||
|
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
|
||||||
|
if errv4 != nil && errv6 != nil {
|
||||||
|
return errors.New("failed to get default next hops")
|
||||||
|
}
|
||||||
|
|
||||||
|
if errv4 == nil {
|
||||||
|
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
||||||
|
}
|
||||||
|
if errv6 == nil {
|
||||||
|
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// continue if either route was found
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
||||||
|
|
||||||
|
if err := backoff.Retry(operation, expBackOff); err != nil {
|
||||||
|
return fmt.Errorf("failed to get default next hops: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// recover in case sys ops panic
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
event := make(chan struct{}, 1)
|
||||||
|
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||||
|
|
||||||
|
// debounce changes
|
||||||
|
timer := time.NewTimer(0)
|
||||||
|
timer.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-event:
|
||||||
|
timer.Reset(debounceTime)
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
timer.Stop()
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the network monitor.
|
||||||
|
func (nw *NetworkMonitor) Stop() {
|
||||||
|
nw.mu.Lock()
|
||||||
|
defer nw.mu.Unlock()
|
||||||
|
|
||||||
|
if nw.cancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nw.cancel()
|
||||||
|
nw.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
|
||||||
|
for {
|
||||||
|
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
|
||||||
|
close(event)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// prevent blocking
|
||||||
|
select {
|
||||||
|
case event <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
//go:build !ios && !android
|
|
||||||
|
|
||||||
package networkmonitor
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"runtime/debug"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
|
|
||||||
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
nw.mu.Lock()
|
|
||||||
ctx, nw.cancel = context.WithCancel(ctx)
|
|
||||||
nw.mu.Unlock()
|
|
||||||
|
|
||||||
nw.wg.Add(1)
|
|
||||||
defer nw.wg.Done()
|
|
||||||
|
|
||||||
var nexthop4, nexthop6 systemops.Nexthop
|
|
||||||
|
|
||||||
operation := func() error {
|
|
||||||
var errv4, errv6 error
|
|
||||||
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
|
|
||||||
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
|
|
||||||
|
|
||||||
if errv4 != nil && errv6 != nil {
|
|
||||||
return errors.New("failed to get default next hops")
|
|
||||||
}
|
|
||||||
|
|
||||||
if errv4 == nil {
|
|
||||||
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
|
|
||||||
}
|
|
||||||
if errv6 == nil {
|
|
||||||
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// continue if either route was found
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
|
|
||||||
|
|
||||||
if err := backoff.Retry(operation, expBackOff); err != nil {
|
|
||||||
return fmt.Errorf("failed to get default next hops: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// recover in case sys ops panic
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
|
|
||||||
return fmt.Errorf("check change: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop stops the network monitor.
|
|
||||||
func (nw *NetworkMonitor) Stop() {
|
|
||||||
nw.mu.Lock()
|
|
||||||
defer nw.mu.Unlock()
|
|
||||||
|
|
||||||
if nw.cancel != nil {
|
|
||||||
nw.cancel()
|
|
||||||
nw.wg.Wait()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,10 +2,21 @@
|
|||||||
|
|
||||||
package networkmonitor
|
package networkmonitor
|
||||||
|
|
||||||
import "context"
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
func (nw *NetworkMonitor) Start(context.Context, func()) error {
|
type NetworkMonitor struct {
|
||||||
return nil
|
}
|
||||||
|
|
||||||
|
// New creates a new network monitor.
|
||||||
|
func New() *NetworkMonitor {
|
||||||
|
return &NetworkMonitor{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nw *NetworkMonitor) Listen(_ context.Context) error {
|
||||||
|
return fmt.Errorf("network monitor not supported on mobile platforms")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (nw *NetworkMonitor) Stop() {
|
func (nw *NetworkMonitor) Stop() {
|
||||||
|
|||||||
99
client/internal/networkmonitor/monitor_test.go
Normal file
99
client/internal/networkmonitor/monitor_test.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MocMultiEvent struct {
|
||||||
|
counter int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
|
if m.counter == 0 {
|
||||||
|
<-ctx.Done()
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
m.counter--
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMonitor_Close(t *testing.T) {
|
||||||
|
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
|
<-ctx.Done()
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
nw := New()
|
||||||
|
|
||||||
|
var resErr error
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
resErr = nw.Listen(context.Background())
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(1 * time.Second) // wait for the goroutine to start
|
||||||
|
nw.Stop()
|
||||||
|
|
||||||
|
<-done
|
||||||
|
if !errors.Is(resErr, context.Canceled) {
|
||||||
|
t.Errorf("unexpected error: %v", resErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMonitor_Event(t *testing.T) {
|
||||||
|
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
|
timeout, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timeout.Done():
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
nw := New()
|
||||||
|
defer nw.Stop()
|
||||||
|
|
||||||
|
var resErr error
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
resErr = nw.Listen(context.Background())
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-done
|
||||||
|
if !errors.Is(resErr, nil) {
|
||||||
|
t.Errorf("unexpected error: %v", nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMonitor_MultiEvent(t *testing.T) {
|
||||||
|
eventsRepeated := 3
|
||||||
|
me := &MocMultiEvent{counter: eventsRepeated}
|
||||||
|
checkChangeFn = me.checkChange
|
||||||
|
|
||||||
|
nw := New()
|
||||||
|
defer nw.Stop()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
started := time.Now()
|
||||||
|
go func() {
|
||||||
|
if resErr := nw.Listen(context.Background()); resErr != nil {
|
||||||
|
t.Errorf("unexpected error: %v", resErr)
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-done
|
||||||
|
expectedResponseTime := time.Duration(eventsRepeated)*time.Second + debounceTime
|
||||||
|
if time.Since(started) < expectedResponseTime {
|
||||||
|
t.Errorf("unexpected duration: %v", time.Since(started))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -114,6 +114,9 @@ type Conn struct {
|
|||||||
|
|
||||||
guard *guard.Guard
|
guard *guard.Guard
|
||||||
semaphore *semaphoregroup.SemaphoreGroup
|
semaphore *semaphoregroup.SemaphoreGroup
|
||||||
|
|
||||||
|
// debug purpose
|
||||||
|
dumpState *stateDump
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn creates a new not opened Conn to the remote peer.
|
// NewConn creates a new not opened Conn to the remote peer.
|
||||||
@@ -137,10 +140,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
statusRelay: NewAtomicConnStatus(),
|
statusRelay: NewAtomicConnStatus(),
|
||||||
statusICE: NewAtomicConnStatus(),
|
statusICE: NewAtomicConnStatus(),
|
||||||
semaphore: semaphore,
|
semaphore: semaphore,
|
||||||
|
dumpState: newStateDump(connLog),
|
||||||
}
|
}
|
||||||
|
|
||||||
ctrl := isController(config)
|
ctrl := isController(config)
|
||||||
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager)
|
conn.workerRelay = NewWorkerRelay(connLog, ctrl, config, conn, relayManager, conn.dumpState)
|
||||||
|
|
||||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||||
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
workerICE, err := NewWorkerICE(ctx, connLog, config, conn, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally)
|
||||||
@@ -160,6 +164,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
|
|||||||
|
|
||||||
go conn.handshaker.Listen()
|
go conn.handshaker.Listen()
|
||||||
|
|
||||||
|
go conn.dumpState.Start(ctx)
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,6 +198,7 @@ func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) {
|
|||||||
defer conn.semaphore.Done(conn.ctx)
|
defer conn.semaphore.Done(conn.ctx)
|
||||||
conn.waitInitialRandomSleepTime(ctx)
|
conn.waitInitialRandomSleepTime(ctx)
|
||||||
|
|
||||||
|
conn.dumpState.SendOffer()
|
||||||
err := conn.handshaker.sendOffer()
|
err := conn.handshaker.sendOffer()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to send initial offer: %v", err)
|
conn.log.Errorf("failed to send initial offer: %v", err)
|
||||||
@@ -251,12 +257,14 @@ func (conn *Conn) Close() {
|
|||||||
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
|
||||||
// doesn't block, discards the message if connection wasn't ready
|
// doesn't block, discards the message if connection wasn't ready
|
||||||
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
|
||||||
conn.log.Debugf("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay)
|
conn.dumpState.RemoteAnswer()
|
||||||
|
conn.log.Infof("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay)
|
||||||
return conn.handshaker.OnRemoteAnswer(answer)
|
return conn.handshaker.OnRemoteAnswer(answer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
|
||||||
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
|
||||||
|
conn.dumpState.RemoteCandidate()
|
||||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,7 +286,8 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string)) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
|
func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
|
||||||
conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
conn.dumpState.RemoteOffer()
|
||||||
|
conn.log.Infof("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
|
||||||
return conn.handshaker.OnRemoteOffer(offer)
|
return conn.handshaker.OnRemoteOffer(offer)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -322,6 +331,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
|||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Infof("set ICE to active connection")
|
conn.log.Infof("set ICE to active connection")
|
||||||
|
conn.dumpState.P2PConnected()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ep *net.UDPAddr
|
ep *net.UDPAddr
|
||||||
@@ -329,6 +339,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
if iceConnInfo.RelayedOnLocal {
|
if iceConnInfo.RelayedOnLocal {
|
||||||
|
conn.dumpState.NewLocalProxy()
|
||||||
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
wgProxy, err = conn.newProxy(iceConnInfo.RemoteConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
|
||||||
@@ -390,6 +401,7 @@ func (conn *Conn) onICEStateDisconnected() {
|
|||||||
// switch back to relay connection
|
// switch back to relay connection
|
||||||
if conn.isReadyToUpgrade() {
|
if conn.isReadyToUpgrade() {
|
||||||
conn.log.Infof("ICE disconnected, set Relay to active connection")
|
conn.log.Infof("ICE disconnected, set Relay to active connection")
|
||||||
|
conn.dumpState.SwitchToRelay()
|
||||||
conn.wgProxyRelay.Work()
|
conn.wgProxyRelay.Work()
|
||||||
|
|
||||||
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
|
||||||
@@ -432,6 +444,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
conn.dumpState.RelayConnected()
|
||||||
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
conn.log.Debugf("Relay connection has been established, setup the WireGuard")
|
||||||
|
|
||||||
wgProxy, err := conn.newProxy(rci.relayedConn)
|
wgProxy, err := conn.newProxy(rci.relayedConn)
|
||||||
@@ -439,11 +452,12 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
conn.dumpState.NewLocalProxy()
|
||||||
|
|
||||||
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
|
||||||
|
|
||||||
if conn.iceP2PIsActive() {
|
if conn.isICEActive() {
|
||||||
conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
|
||||||
conn.setRelayedProxy(wgProxy)
|
conn.setRelayedProxy(wgProxy)
|
||||||
conn.statusRelay.Set(StatusConnected)
|
conn.statusRelay.Set(StatusConnected)
|
||||||
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
|
||||||
@@ -481,10 +495,10 @@ func (conn *Conn) onRelayDisconnected() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.log.Debugf("relay connection is disconnected")
|
conn.log.Infof("relay connection is disconnected")
|
||||||
|
|
||||||
if conn.currentConnPriority == connPriorityRelay {
|
if conn.currentConnPriority == connPriorityRelay {
|
||||||
conn.log.Debugf("clean up WireGuard config")
|
conn.log.Infof("clean up WireGuard config")
|
||||||
if err := conn.removeWgPeer(); err != nil {
|
if err := conn.removeWgPeer(); err != nil {
|
||||||
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
@@ -516,7 +530,8 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-conn.guard.Reconnect:
|
case <-conn.guard.Reconnect:
|
||||||
conn.log.Debugf("send offer to peer")
|
conn.log.Infof("send offer to peer")
|
||||||
|
conn.dumpState.SendOffer()
|
||||||
if err := conn.handshaker.SendOffer(); err != nil {
|
if err := conn.handshaker.SendOffer(); err != nil {
|
||||||
conn.log.Errorf("failed to send offer: %v", err)
|
conn.log.Errorf("failed to send offer: %v", err)
|
||||||
}
|
}
|
||||||
@@ -711,8 +726,8 @@ func (conn *Conn) isReadyToUpgrade() bool {
|
|||||||
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) iceP2PIsActive() bool {
|
func (conn *Conn) isICEActive() bool {
|
||||||
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
|
return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) removeWgPeer() error {
|
func (conn *Conn) removeWgPeer() error {
|
||||||
|
|||||||
@@ -76,19 +76,19 @@ func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAn
|
|||||||
|
|
||||||
func (h *Handshaker) Listen() {
|
func (h *Handshaker) Listen() {
|
||||||
for {
|
for {
|
||||||
h.log.Debugf("wait for remote offer confirmation")
|
h.log.Info("wait for remote offer confirmation")
|
||||||
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
|
remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var connectionClosedError *ConnectionClosedError
|
var connectionClosedError *ConnectionClosedError
|
||||||
if errors.As(err, &connectionClosedError) {
|
if errors.As(err, &connectionClosedError) {
|
||||||
h.log.Tracef("stop handshaker")
|
h.log.Info("exit from handshaker")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.log.Errorf("failed to received remote offer confirmation: %s", err)
|
h.log.Errorf("failed to received remote offer confirmation: %s", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
|
h.log.Infof("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
|
||||||
for _, listener := range h.onNewOfferListeners {
|
for _, listener := range h.onNewOfferListeners {
|
||||||
go listener(remoteOfferAnswer)
|
go listener(remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
@@ -108,7 +108,7 @@ func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
|
|||||||
case h.remoteOffersCh <- offer:
|
case h.remoteOffersCh <- offer:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
h.log.Debugf("OnRemoteOffer skipping message because is not ready")
|
h.log.Warnf("OnRemoteOffer skipping message because is not ready")
|
||||||
// connection might not be ready yet to receive so we ignore the message
|
// connection might not be ready yet to receive so we ignore the message
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -131,8 +131,7 @@ func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
|
|||||||
select {
|
select {
|
||||||
case remoteOfferAnswer := <-h.remoteOffersCh:
|
case remoteOfferAnswer := <-h.remoteOffersCh:
|
||||||
// received confirmation from the remote peer -> ready to proceed
|
// received confirmation from the remote peer -> ready to proceed
|
||||||
err := h.sendAnswer()
|
if err := h.sendAnswer(); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &remoteOfferAnswer, nil
|
return &remoteOfferAnswer, nil
|
||||||
@@ -168,7 +167,7 @@ func (h *Handshaker) sendOffer() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) sendAnswer() error {
|
func (h *Handshaker) sendAnswer() error {
|
||||||
h.log.Debugf("sending answer")
|
h.log.Infof("sending answer")
|
||||||
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
uFrag, pwd := h.ice.GetLocalUserCredentials()
|
||||||
|
|
||||||
answer := OfferAnswer{
|
answer := OfferAnswer{
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,4 +17,5 @@ type WGIface interface {
|
|||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats(peerKey string) (configurer.WGStats, error)
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
|
Address() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|||||||
112
client/internal/peer/state_dump.go
Normal file
112
client/internal/peer/state_dump.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stateDump struct {
|
||||||
|
log *log.Entry
|
||||||
|
|
||||||
|
sentOffer int
|
||||||
|
remoteOffer int
|
||||||
|
remoteAnswer int
|
||||||
|
remoteCandidate int
|
||||||
|
p2pConnected int
|
||||||
|
switchToRelay int
|
||||||
|
wgCheckSuccess int
|
||||||
|
relayConnected int
|
||||||
|
localProxies int
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStateDump(log *log.Entry) *stateDump {
|
||||||
|
return &stateDump{
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) Start(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(10 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
s.dumpState()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) RemoteOffer() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.remoteOffer++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) RemoteCandidate() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.remoteCandidate++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) SendOffer() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.sentOffer++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) dumpState() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.log.Infof("Dump stat: SentOffer: %d, RemoteOffer: %d, RemoteAnswer: %d, RemoteCandidate: %d, P2PConnected: %d, SwitchToRelay: %d, WGCheckSuccess: %d, RelayConnected: %d, LocalProxies: %d",
|
||||||
|
s.sentOffer, s.remoteOffer, s.remoteAnswer, s.remoteCandidate, s.p2pConnected, s.switchToRelay, s.wgCheckSuccess, s.relayConnected, s.localProxies)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) RemoteAnswer() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.remoteAnswer++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) P2PConnected() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.p2pConnected++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) SwitchToRelay() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.switchToRelay++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) WGcheckSuccess() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.wgCheckSuccess++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) RelayConnected() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.relayConnected++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateDump) NewLocalProxy() {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
s.localProxies++
|
||||||
|
}
|
||||||
@@ -14,7 +14,9 @@ import (
|
|||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/management/domain"
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
@@ -132,13 +134,14 @@ type NSGroupState struct {
|
|||||||
|
|
||||||
// FullStatus contains the full state held by the Status instance
|
// FullStatus contains the full state held by the Status instance
|
||||||
type FullStatus struct {
|
type FullStatus struct {
|
||||||
Peers []State
|
Peers []State
|
||||||
ManagementState ManagementState
|
ManagementState ManagementState
|
||||||
SignalState SignalState
|
SignalState SignalState
|
||||||
LocalPeerState LocalPeerState
|
LocalPeerState LocalPeerState
|
||||||
RosenpassState RosenpassState
|
RosenpassState RosenpassState
|
||||||
Relays []relay.ProbeResult
|
Relays []relay.ProbeResult
|
||||||
NSGroupStates []NSGroupState
|
NSGroupStates []NSGroupState
|
||||||
|
NumOfForwardingRules int
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status holds a state of peers, signal, management connections and relays
|
// Status holds a state of peers, signal, management connections and relays
|
||||||
@@ -171,6 +174,8 @@ type Status struct {
|
|||||||
eventMux sync.RWMutex
|
eventMux sync.RWMutex
|
||||||
eventStreams map[string]chan *proto.SystemEvent
|
eventStreams map[string]chan *proto.SystemEvent
|
||||||
eventQueue *EventQueue
|
eventQueue *EventQueue
|
||||||
|
|
||||||
|
ingressGwMgr *ingressgw.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRecorder returns a new Status instance
|
// NewRecorder returns a new Status instance
|
||||||
@@ -193,6 +198,12 @@ func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
|
|||||||
d.relayMgr = manager
|
d.relayMgr = manager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Status) SetIngressGwMgr(ingressGwMgr *ingressgw.Manager) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
d.ingressGwMgr = ingressGwMgr
|
||||||
|
}
|
||||||
|
|
||||||
// ReplaceOfflinePeers replaces
|
// ReplaceOfflinePeers replaces
|
||||||
func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
func (d *Status) ReplaceOfflinePeers(replacement []State) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
@@ -235,6 +246,18 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
|||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
|
for _, state := range d.peers {
|
||||||
|
if state.IP == ip {
|
||||||
|
return state.FQDN, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
// RemovePeer removes peer from Daemon status map
|
// RemovePeer removes peer from Daemon status map
|
||||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
@@ -734,6 +757,16 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
return append(relayStates, relayState)
|
return append(relayStates, relayState)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||||
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
if d.ingressGwMgr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return d.ingressGwMgr.Rules()
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Status) GetDNSStates() []NSGroupState {
|
func (d *Status) GetDNSStates() []NSGroupState {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
defer d.mux.Unlock()
|
defer d.mux.Unlock()
|
||||||
@@ -751,11 +784,12 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo
|
|||||||
// GetFullStatus gets full status
|
// GetFullStatus gets full status
|
||||||
func (d *Status) GetFullStatus() FullStatus {
|
func (d *Status) GetFullStatus() FullStatus {
|
||||||
fullStatus := FullStatus{
|
fullStatus := FullStatus{
|
||||||
ManagementState: d.GetManagementState(),
|
ManagementState: d.GetManagementState(),
|
||||||
SignalState: d.GetSignalState(),
|
SignalState: d.GetSignalState(),
|
||||||
Relays: d.GetRelayStates(),
|
Relays: d.GetRelayStates(),
|
||||||
RosenpassState: d.GetRosenpassState(),
|
RosenpassState: d.GetRosenpassState(),
|
||||||
NSGroupStates: d.GetDNSStates(),
|
NSGroupStates: d.GetDNSStates(),
|
||||||
|
NumOfForwardingRules: len(d.ForwardingRules()),
|
||||||
}
|
}
|
||||||
|
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ type WGWatcher struct {
|
|||||||
log *log.Entry
|
log *log.Entry
|
||||||
wgIfaceStater WGInterfaceStater
|
wgIfaceStater WGInterfaceStater
|
||||||
peerKey string
|
peerKey string
|
||||||
|
stateDump *stateDump
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
@@ -34,11 +35,12 @@ type WGWatcher struct {
|
|||||||
waitGroup sync.WaitGroup
|
waitGroup sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher {
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||||
return &WGWatcher{
|
return &WGWatcher{
|
||||||
log: log,
|
log: log,
|
||||||
wgIfaceStater: wgIfaceStater,
|
wgIfaceStater: wgIfaceStater,
|
||||||
peerKey: peerKey,
|
peerKey: peerKey,
|
||||||
|
stateDump: stateDump,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -105,6 +107,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
|
|||||||
|
|
||||||
resetTime := time.Until(handshake.Add(checkPeriod))
|
resetTime := time.Until(handshake.Add(checkPeriod))
|
||||||
timer.Reset(resetTime)
|
timer.Reset(resetTime)
|
||||||
|
w.stateDump.WGcheckSuccess()
|
||||||
|
|
||||||
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
w.log.Debugf("WireGuard watcher reset timer: %v", resetTime)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
|||||||
|
|
||||||
mlog := log.WithField("peer", "tet")
|
mlog := log.WithField("peer", "tet")
|
||||||
mocWgIface := &MocWgIface{}
|
mocWgIface := &MocWgIface{}
|
||||||
watcher := NewWGWatcher(mlog, mocWgIface, "")
|
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump(mlog))
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -72,7 +72,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
|||||||
|
|
||||||
mlog := log.WithField("peer", "tet")
|
mlog := log.WithField("peer", "tet")
|
||||||
mocWgIface := &MocWgIface{}
|
mocWgIface := &MocWgIface{}
|
||||||
watcher := NewWGWatcher(mlog, mocWgIface, "")
|
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump(mlog))
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|||||||
@@ -358,6 +358,12 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
|||||||
}
|
}
|
||||||
|
|
||||||
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
|
||||||
|
addr, err := netip.ParseAddr(candidate.Address())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
var routePrefixes []netip.Prefix
|
var routePrefixes []netip.Prefix
|
||||||
for _, routes := range clientRoutes {
|
for _, routes := range clientRoutes {
|
||||||
if len(routes) > 0 && routes[0] != nil {
|
if len(routes) > 0 && routes[0] != nil {
|
||||||
@@ -365,14 +371,8 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := netip.ParseAddr(candidate.Address())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range routePrefixes {
|
for _, prefix := range routePrefixes {
|
||||||
// default route is
|
// default route is handled by route exclusion / ip rules
|
||||||
if prefix.Bits() == 0 {
|
if prefix.Bits() == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,14 +33,14 @@ type WorkerRelay struct {
|
|||||||
wgWatcher *WGWatcher
|
wgWatcher *WGWatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService) *WorkerRelay {
|
func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, conn *Conn, relayManager relayClient.ManagerService, stateDump *stateDump) *WorkerRelay {
|
||||||
r := &WorkerRelay{
|
r := &WorkerRelay{
|
||||||
log: log,
|
log: log,
|
||||||
isController: ctrl,
|
isController: ctrl,
|
||||||
config: config,
|
config: config,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
relayManager: relayManager,
|
relayManager: relayManager,
|
||||||
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key),
|
wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key, stateDump),
|
||||||
}
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -302,7 +302,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
|
|||||||
|
|
||||||
// If the chosen route is the same as the current route, do nothing
|
// If the chosen route is the same as the current route, do nothing
|
||||||
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
|
if c.currentChosen != nil && c.currentChosen.ID == newChosenID &&
|
||||||
c.currentChosen.IsEqual(c.routes[newChosenID]) {
|
c.currentChosen.Equal(c.routes[newChosenID]) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -160,6 +160,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
||||||
|
if r.Extra == nil {
|
||||||
|
r.SetEdns0(4096, false)
|
||||||
|
r.MsgHdr.AuthenticatedData = true
|
||||||
|
}
|
||||||
|
|
||||||
client := &dns.Client{
|
client := &dns.Client{
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
Net: "udp",
|
Net: "udp",
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ package iface
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgIfaceBase interface {
|
type wgIfaceBase interface {
|
||||||
@@ -13,7 +13,7 @@ type wgIfaceBase interface {
|
|||||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||||
|
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
ToInterface() *net.Interface
|
ToInterface() *net.Interface
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
|
|||||||
51
client/internal/routemanager/ipfwdstate/ipfwdstate.go
Normal file
51
client/internal/routemanager/ipfwdstate/ipfwdstate.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package ipfwdstate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPForwardingState is a struct that keeps track of the IP forwarding state.
|
||||||
|
// todo: read initial state of the IP forwarding from the system and reset the state based on it
|
||||||
|
type IPForwardingState struct {
|
||||||
|
enabledCounter int
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIPForwardingState() *IPForwardingState {
|
||||||
|
return &IPForwardingState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *IPForwardingState) RequestForwarding() error {
|
||||||
|
if f.enabledCounter != 0 {
|
||||||
|
f.enabledCounter++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := systemops.EnableIPForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("failed to enable IP forwarding with sysctl: %w", err)
|
||||||
|
}
|
||||||
|
f.enabledCounter = 1
|
||||||
|
log.Info("IP forwarding enabled")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *IPForwardingState) ReleaseForwarding() error {
|
||||||
|
if f.enabledCounter == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.enabledCounter > 1 {
|
||||||
|
f.enabledCounter--
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if failed to disable IP forwarding we anyway decrement the counter
|
||||||
|
f.enabledCounter = 0
|
||||||
|
|
||||||
|
// todo call systemops.DisableIPForwarding()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,7 +40,7 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
|||||||
|
|
||||||
for routeID := range m.routes {
|
for routeID := range m.routes {
|
||||||
update, found := routesMap[routeID]
|
update, found := routesMap[routeID]
|
||||||
if !found || !update.IsEqual(m.routes[routeID]) {
|
if !found || !update.Equal(m.routes[routeID]) {
|
||||||
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
serverRoutesToRemove = append(serverRoutesToRemove, routeID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,9 +70,6 @@ func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(m.routes) > 0 {
|
if len(m.routes) > 0 {
|
||||||
if err := systemops.EnableIPForwarding(); err != nil {
|
|
||||||
return fmt.Errorf("enable ip forwarding: %w", err)
|
|
||||||
}
|
|
||||||
if err := m.firewall.EnableRouting(); err != nil {
|
if err := m.firewall.EnableRouting(); err != nil {
|
||||||
return fmt.Errorf("enable routing: %w", err)
|
return fmt.Errorf("enable routing: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user