mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 11:46:40 +00:00
Compare commits
27 Commits
debug-and-
...
v0.58.2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25ed58328a | ||
|
|
644ed4b934 | ||
|
|
58faa341d2 | ||
|
|
5853b5553c | ||
|
|
998fb30e1e | ||
|
|
e254b4cde5 | ||
|
|
ead1c618ba | ||
|
|
55126f990c | ||
|
|
90577682e4 | ||
|
|
dc30dcacce | ||
|
|
2c87fa6236 | ||
|
|
ec8d83ade4 | ||
|
|
3130cce72d | ||
|
|
bd23ab925e | ||
|
|
0c6f671a7c | ||
|
|
cf7f6c355f | ||
|
|
47e64d72db | ||
|
|
9e81e782e5 | ||
|
|
7aef0f67df | ||
|
|
dba7ef667d | ||
|
|
69d87343d2 | ||
|
|
5113c70943 | ||
|
|
ad8fcda67b | ||
|
|
d33f88df82 | ||
|
|
786ca6fc79 | ||
|
|
dfebdf1444 | ||
|
|
a8dcff69c2 |
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -217,7 +217,7 @@ jobs:
|
|||||||
- arch: "386"
|
- arch: "386"
|
||||||
raceFlag: ""
|
raceFlag: ""
|
||||||
- arch: "amd64"
|
- arch: "amd64"
|
||||||
raceFlag: ""
|
raceFlag: "-race"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
|
|||||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.22"
|
SIGN_PIPE_VER: "v0.0.23"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<br/>
|
<br/>
|
||||||
<br/>
|
<br/>
|
||||||
@@ -52,7 +53,7 @@
|
|||||||
|
|
||||||
### Open Source Network Security in a Single Platform
|
### Open Source Network Security in a Single Platform
|
||||||
|
|
||||||
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### NetBird on Lawrence Systems (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ ENV \
|
|||||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
||||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package android
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -18,7 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/util/net"
|
"github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
@@ -83,7 +84,8 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
|
exportEnvList(envList)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
@@ -118,7 +120,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
// In this case make no sense handle registration steps.
|
// In this case make no sense handle registration steps.
|
||||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
|
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
|
exportEnvList(envList)
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: c.cfgFile,
|
||||||
})
|
})
|
||||||
@@ -249,3 +252,14 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
|
|||||||
func (c *Client) RemoveConnectionListener() {
|
func (c *Client) RemoveConnectionListener() {
|
||||||
c.recorder.RemoveConnectionListener()
|
c.recorder.RemoveConnectionListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func exportEnvList(list *EnvList) {
|
||||||
|
if list == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for k, v := range list.AllItems() {
|
||||||
|
if err := os.Setenv(k, v); err != nil {
|
||||||
|
log.Errorf("could not set env variable %s: %v", k, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
32
client/android/env_list.go
Normal file
32
client/android/env_list.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package android
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// EnvKeyNBForceRelay Exported for Android java client
|
||||||
|
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnvList wraps a Go map for export to Java
|
||||||
|
type EnvList struct {
|
||||||
|
data map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEnvList creates a new EnvList
|
||||||
|
func NewEnvList() *EnvList {
|
||||||
|
return &EnvList{data: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put adds a key-value pair
|
||||||
|
func (el *EnvList) Put(key, value string) {
|
||||||
|
el.data[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value by key
|
||||||
|
func (el *EnvList) Get(key string) string {
|
||||||
|
return el.data[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (el *EnvList) AllItems() map[string]string {
|
||||||
|
return el.data
|
||||||
|
}
|
||||||
@@ -33,6 +33,7 @@ type ErrListener interface {
|
|||||||
// the backend want to show an url for the user
|
// the backend want to show an url for the user
|
||||||
type URLOpener interface {
|
type URLOpener interface {
|
||||||
Open(string)
|
Open(string)
|
||||||
|
OnLoginSuccess()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Auth can register or login new client
|
// Auth can register or login new client
|
||||||
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
err = a.withBackOff(a.ctx, func() error {
|
err = a.withBackOff(a.ctx, func() error {
|
||||||
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
err := internal.Login(a.ctx, a.config, "", jwtToken)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
go urlOpener.OnLoginSuccess()
|
||||||
|
}
|
||||||
|
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// update host's static platform and system information
|
// update host's static platform and system information
|
||||||
system.UpdateStaticInfo()
|
system.UpdateStaticInfoAsync()
|
||||||
|
|
||||||
configFilePath, err := activeProf.FilePath()
|
configFilePath, err := activeProf.FilePath()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
|
|||||||
|
|
||||||
// DialClientGRPCServer returns client connection to the daemon server.
|
// DialClientGRPCServer returns client connection to the daemon server.
|
||||||
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
|
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
return grpc.DialContext(
|
return grpc.DialContext(
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
log.Info("starting NetBird service") //nolint
|
log.Info("starting NetBird service") //nolint
|
||||||
|
|
||||||
// Collect static system and platform information
|
// Collect static system and platform information
|
||||||
system.UpdateStaticInfo()
|
system.UpdateStaticInfoAsync()
|
||||||
|
|
||||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||||
p.serv = grpc.NewServer()
|
p.serv = grpc.NewServer()
|
||||||
|
|||||||
@@ -9,29 +9,26 @@ import (
|
|||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
clientProto "github.com/netbirdio/netbird/client/proto"
|
||||||
|
client "github.com/netbirdio/netbird/client/server"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
mgmt "github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peers"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
|
||||||
|
|
||||||
clientProto "github.com/netbirdio/netbird/client/proto"
|
|
||||||
client "github.com/netbirdio/netbird/client/server"
|
|
||||||
mgmt "github.com/netbirdio/netbird/management/server"
|
|
||||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
|
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
sig "github.com/netbirdio/netbird/signal/server"
|
sig "github.com/netbirdio/netbird/signal/server"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func startTestingServices(t *testing.T) string {
|
func startTestingServices(t *testing.T) string {
|
||||||
@@ -90,15 +87,20 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
|
||||||
require.NoError(t, err)
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
t.Cleanup(ctrl.Finish)
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
|
||||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
peersmanager := peers.NewManager(store, permissionsManagerMock)
|
||||||
|
settingsManagerMock := settings.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
settingsMockManager.EXPECT().
|
settingsMockManager.EXPECT().
|
||||||
|
|||||||
@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
|||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
status, err := client.Status(ctx, &proto.StatusRequest{
|
||||||
|
WaitForReady: func() *bool { b := true; return &b }(),
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
// TODO: make after-startup backoff err available
|
// TODO: make after-startup backoff err available
|
||||||
run := make(chan struct{}, 1)
|
run := make(chan struct{})
|
||||||
clientErr := make(chan error, 1)
|
clientErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := client.Run(run); err != nil {
|
if err := client.Run(run); err != nil {
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
// constants needed to manage and create iptable rules
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func isIptablesSupported() bool {
|
func isIptablesSupported() bool {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -20,8 +20,9 @@ import (
|
|||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func WithCustomDialer() grpc.DialOption {
|
func WithCustomDialer() grpc.DialOption {
|
||||||
@@ -57,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
if tlsEnabled {
|
if tlsEnabled {
|
||||||
certPool, err := x509.SystemCertPool()
|
certPool, err := x509.SystemCertPool()
|
||||||
@@ -71,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
conn, err := grpc.DialContext(
|
||||||
@@ -3,7 +3,7 @@ package bind
|
|||||||
import (
|
import (
|
||||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||||
|
|||||||
@@ -15,8 +15,9 @@ import (
|
|||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
type RecvMessage struct {
|
||||||
@@ -44,7 +45,7 @@ type ICEBind struct {
|
|||||||
RecvChan chan RecvMessage
|
RecvChan chan RecvMessage
|
||||||
|
|
||||||
transportNet transport.Net
|
transportNet transport.Net
|
||||||
filterFn FilterFn
|
filterFn udpmux.FilterFn
|
||||||
endpoints map[netip.Addr]net.Conn
|
endpoints map[netip.Addr]net.Conn
|
||||||
endpointsMu sync.Mutex
|
endpointsMu sync.Mutex
|
||||||
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
|
||||||
@@ -54,13 +55,13 @@ type ICEBind struct {
|
|||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
mtu uint16
|
mtu uint16
|
||||||
activityRecorder *ActivityRecorder
|
activityRecorder *ActivityRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
|
||||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
ib := &ICEBind{
|
ib := &ICEBind{
|
||||||
StdNetBind: b,
|
StdNetBind: b,
|
||||||
@@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
s.muUDPMux.Lock()
|
s.muUDPMux.Lock()
|
||||||
defer s.muUDPMux.Unlock()
|
defer s.muUDPMux.Unlock()
|
||||||
if s.udpMux == nil {
|
if s.udpMux == nil {
|
||||||
@@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
s.muUDPMux.Lock()
|
s.muUDPMux.Lock()
|
||||||
defer s.muUDPMux.Unlock()
|
defer s.muUDPMux.Unlock()
|
||||||
|
|
||||||
s.udpMux = NewUniversalUDPMuxDefault(
|
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
|
||||||
UniversalUDPMuxParams{
|
udpmux.UniversalUDPMuxParams{
|
||||||
UDPConn: nbnet.WrapPacketConn(conn),
|
UDPConn: nbnet.WrapPacketConn(conn),
|
||||||
Net: s.transportNet,
|
Net: s.transportNet,
|
||||||
FilterFn: s.filterFn,
|
FilterFn: s.filterFn,
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
//go:build ios
|
|
||||||
|
|
||||||
package bind
|
|
||||||
|
|
||||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
|
||||||
// iOS doesn't support nbnet hooks, so this is a no-op
|
|
||||||
}
|
|
||||||
@@ -17,8 +17,8 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -394,6 +394,13 @@ func toLastHandshake(stringVar string) (time.Time, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If sec is 0 (Unix epoch), return zero time instead
|
||||||
|
// This indicates no handshake has occurred
|
||||||
|
if sec == 0 {
|
||||||
|
return time.Time{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
return time.Unix(sec, 0), nil
|
return time.Unix(sec, 0), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -402,7 +409,7 @@ func toBytes(s string) (int64, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
|
||||||
return nbnet.ControlPlaneMark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -7,14 +7,14 @@ import (
|
|||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGTunDevice interface {
|
type WGTunDevice interface {
|
||||||
Create() (device.WGConfigurer, error)
|
Create() (device.WGConfigurer, error)
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(address wgaddr.Address) error
|
UpdateAddr(address wgaddr.Address) error
|
||||||
WgAddress() wgaddr.Address
|
WgAddress() wgaddr.Address
|
||||||
MTU() uint16
|
MTU() uint16
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,7 +30,7 @@ type WGTunDevice struct {
|
|||||||
name string
|
name string
|
||||||
device *device.Device
|
device *device.Device
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
}
|
}
|
||||||
return t.configurer, nil
|
return t.configurer, nil
|
||||||
}
|
}
|
||||||
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
err := t.device.Up()
|
err := t.device.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,7 +27,7 @@ type TunDevice struct {
|
|||||||
|
|
||||||
device *device.Device
|
device *device.Device
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return t.configurer, nil
|
return t.configurer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
err := t.device.Up()
|
err := t.device.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,7 +29,7 @@ type TunDevice struct {
|
|||||||
|
|
||||||
device *device.Device
|
device *device.Device
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return t.configurer, nil
|
return t.configurer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
err := t.device.Up()
|
err := t.device.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/sharedsock"
|
"github.com/netbirdio/netbird/sharedsock"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunKernelDevice struct {
|
type TunKernelDevice struct {
|
||||||
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
|
|||||||
|
|
||||||
link *wgLink
|
link *wgLink
|
||||||
udpMuxConn net.PacketConn
|
udpMuxConn net.PacketConn
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
|
|
||||||
filterFn bind.FilterFn
|
filterFn udpmux.FilterFn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
|
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
|
||||||
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
|
|||||||
return configurer, nil
|
return configurer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
if t.udpMux != nil {
|
if t.udpMux != nil {
|
||||||
return t.udpMux, nil
|
return t.udpMux, nil
|
||||||
}
|
}
|
||||||
@@ -101,19 +101,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var udpConn net.PacketConn = rawSock
|
bindParams := udpmux.UniversalUDPMuxParams{
|
||||||
if !nbnet.AdvancedRouting() {
|
UDPConn: nbnet.WrapPacketConn(rawSock),
|
||||||
udpConn = nbnet.WrapPacketConn(rawSock)
|
|
||||||
}
|
|
||||||
|
|
||||||
bindParams := bind.UniversalUDPMuxParams{
|
|
||||||
UDPConn: udpConn,
|
|
||||||
Net: t.transportNet,
|
Net: t.transportNet,
|
||||||
FilterFn: t.filterFn,
|
FilterFn: t.filterFn,
|
||||||
WGAddress: t.address,
|
WGAddress: t.address,
|
||||||
MTU: t.mtu,
|
MTU: t.mtu,
|
||||||
}
|
}
|
||||||
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
|
||||||
go mux.ReadFromConn(t.ctx)
|
go mux.ReadFromConn(t.ctx)
|
||||||
t.udpMuxConn = rawSock
|
t.udpMuxConn = rawSock
|
||||||
t.udpMux = mux
|
t.udpMux = mux
|
||||||
|
|||||||
@@ -10,8 +10,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunNetstackDevice struct {
|
type TunNetstackDevice struct {
|
||||||
@@ -26,7 +27,7 @@ type TunNetstackDevice struct {
|
|||||||
device *device.Device
|
device *device.Device
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
nsTun *nbnetstack.NetStackTun
|
nsTun *nbnetstack.NetStackTun
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
|
|
||||||
net *netstack.Net
|
net *netstack.Net
|
||||||
@@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
return t.configurer, nil
|
return t.configurer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
if t.device == nil {
|
if t.device == nil {
|
||||||
return nil, fmt.Errorf("device is not ready yet")
|
return nil, fmt.Errorf("device is not ready yet")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,7 +26,7 @@ type USPDevice struct {
|
|||||||
|
|
||||||
device *device.Device
|
device *device.Device
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
|||||||
return t.configurer, nil
|
return t.configurer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
if t.device == nil {
|
if t.device == nil {
|
||||||
return nil, fmt.Errorf("device is not ready yet")
|
return nil, fmt.Errorf("device is not ready yet")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,7 +30,7 @@ type TunDevice struct {
|
|||||||
device *device.Device
|
device *device.Device
|
||||||
nativeTunDevice *tun.NativeTun
|
nativeTunDevice *tun.NativeTun
|
||||||
filteredDevice *FilteredDevice
|
filteredDevice *FilteredDevice
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return t.configurer, nil
|
return t.configurer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
err := t.device.Up()
|
err := t.device.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -5,14 +5,14 @@ import (
|
|||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGTunDevice interface {
|
type WGTunDevice interface {
|
||||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(address wgaddr.Address) error
|
UpdateAddr(address wgaddr.Address) error
|
||||||
WgAddress() wgaddr.Address
|
WgAddress() wgaddr.Address
|
||||||
MTU() uint16
|
MTU() uint16
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ import (
|
|||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
|
|||||||
MTU uint16
|
MTU uint16
|
||||||
MobileArgs *device.MobileIFaceArguments
|
MobileArgs *device.MobileIFaceArguments
|
||||||
TransportNet transport.Net
|
TransportNet transport.Net
|
||||||
FilterFn bind.FilterFn
|
FilterFn udpmux.FilterFn
|
||||||
DisableDNS bool
|
DisableDNS bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
|
|||||||
|
|
||||||
// Up configures a Wireguard interface
|
// Up configures a Wireguard interface
|
||||||
// The interface must exist before calling this method (e.g. call interface.Create() before)
|
// The interface must exist before calling this method (e.g. call interface.Create() before)
|
||||||
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package bind
|
package udpmux
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
||||||
@@ -16,11 +16,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type udpMuxedConnParams struct {
|
type udpMuxedConnParams struct {
|
||||||
Mux *UDPMuxDefault
|
Mux *SingleSocketUDPMux
|
||||||
AddrPool *sync.Pool
|
AddrPool *sync.Pool
|
||||||
Key string
|
Key string
|
||||||
LocalAddr net.Addr
|
LocalAddr net.Addr
|
||||||
Logger logging.LeveledLogger
|
Logger logging.LeveledLogger
|
||||||
|
CandidateID string
|
||||||
}
|
}
|
||||||
|
|
||||||
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
|
||||||
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *udpMuxedConn) GetCandidateID() string {
|
||||||
|
return c.params.CandidateID
|
||||||
|
}
|
||||||
|
|
||||||
func (c *udpMuxedConn) isClosed() bool {
|
func (c *udpMuxedConn) isClosed() bool {
|
||||||
select {
|
select {
|
||||||
case <-c.closedChan:
|
case <-c.closedChan:
|
||||||
64
client/iface/udpmux/doc.go
Normal file
64
client/iface/udpmux/doc.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
// Package udpmux provides a custom implementation of a UDP multiplexer
|
||||||
|
// that allows multiple logical ICE connections to share a single underlying
|
||||||
|
// UDP socket. This is based on Pion's ICE library, with modifications for
|
||||||
|
// NetBird's requirements.
|
||||||
|
//
|
||||||
|
// # Background
|
||||||
|
//
|
||||||
|
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
|
||||||
|
// Establishment) is responsible for discovering candidate network paths
|
||||||
|
// and maintaining connectivity between peers. Each ICE connection
|
||||||
|
// normally requires a dedicated UDP socket. However, using one socket
|
||||||
|
// per candidate can be inefficient and difficult to manage.
|
||||||
|
//
|
||||||
|
// This package introduces SingleSocketUDPMux, which allows multiple ICE
|
||||||
|
// candidate connections (muxed connections) to share a single UDP socket.
|
||||||
|
// It handles demultiplexing of packets based on ICE ufrag values, STUN
|
||||||
|
// attributes, and candidate IDs.
|
||||||
|
//
|
||||||
|
// # Usage
|
||||||
|
//
|
||||||
|
// The typical flow is:
|
||||||
|
//
|
||||||
|
// 1. Create a UDP socket (net.PacketConn).
|
||||||
|
// 2. Construct Params with the socket and optional logger/net stack.
|
||||||
|
// 3. Call NewSingleSocketUDPMux(params).
|
||||||
|
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
|
||||||
|
// to obtain a logical PacketConn.
|
||||||
|
// 5. Use the returned PacketConn just like a normal UDP connection.
|
||||||
|
//
|
||||||
|
// # STUN Message Routing Logic
|
||||||
|
//
|
||||||
|
// When a STUN packet arrives, the mux decides which connection should
|
||||||
|
// receive it using this routing logic:
|
||||||
|
//
|
||||||
|
// Primary Routing: Candidate Pair ID
|
||||||
|
// - Extract the candidate pair ID from the STUN message using
|
||||||
|
// ice.CandidatePairIDFromSTUN(msg)
|
||||||
|
// - The target candidate is the locally generated candidate that
|
||||||
|
// corresponds to the connection that should handle this STUN message
|
||||||
|
// - If found, use the target candidate ID to lookup the specific
|
||||||
|
// connection in candidateConnMap
|
||||||
|
// - Route the message directly to that connection
|
||||||
|
//
|
||||||
|
// Fallback Routing: Broadcasting
|
||||||
|
// When candidate pair ID is not available or lookup fails:
|
||||||
|
// - Collect connections from addressMap based on source address
|
||||||
|
// - Find connection using username attribute (ufrag) from STUN message
|
||||||
|
// - Remove duplicate connections from the list
|
||||||
|
// - Send the STUN message to all collected connections
|
||||||
|
//
|
||||||
|
// # Peer Reflexive Candidate Discovery
|
||||||
|
//
|
||||||
|
// When a remote peer sends a STUN message from an unknown source address
|
||||||
|
// (from a candidate that has not been exchanged via signal), the ICE
|
||||||
|
// library will:
|
||||||
|
// - Generate a new peer reflexive candidate for this source address
|
||||||
|
// - Extract or assign a candidate ID based on the STUN message attributes
|
||||||
|
// - Create a mapping between the new peer reflexive candidate ID and
|
||||||
|
// the appropriate local connection
|
||||||
|
//
|
||||||
|
// This discovery mechanism ensures that STUN messages from newly discovered
|
||||||
|
// peer reflexive candidates can be properly routed to the correct local
|
||||||
|
// connection without requiring fallback broadcasting.
|
||||||
|
package udpmux
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package bind
|
package udpmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -22,9 +22,9 @@ import (
|
|||||||
|
|
||||||
const receiveMTU = 8192
|
const receiveMTU = 8192
|
||||||
|
|
||||||
// UDPMuxDefault is an implementation of the interface
|
// SingleSocketUDPMux is an implementation of the interface
|
||||||
type UDPMuxDefault struct {
|
type SingleSocketUDPMux struct {
|
||||||
params UDPMuxParams
|
params Params
|
||||||
|
|
||||||
closedChan chan struct{}
|
closedChan chan struct{}
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
@@ -32,6 +32,9 @@ type UDPMuxDefault struct {
|
|||||||
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
|
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
|
||||||
connsIPv4, connsIPv6 map[string]*udpMuxedConn
|
connsIPv4, connsIPv6 map[string]*udpMuxedConn
|
||||||
|
|
||||||
|
// candidateConnMap maps local candidate IDs to their corresponding connection.
|
||||||
|
candidateConnMap map[string]*udpMuxedConn
|
||||||
|
|
||||||
addressMapMu sync.RWMutex
|
addressMapMu sync.RWMutex
|
||||||
addressMap map[string][]*udpMuxedConn
|
addressMap map[string][]*udpMuxedConn
|
||||||
|
|
||||||
@@ -46,8 +49,8 @@ type UDPMuxDefault struct {
|
|||||||
|
|
||||||
const maxAddrSize = 512
|
const maxAddrSize = 512
|
||||||
|
|
||||||
// UDPMuxParams are parameters for UDPMux.
|
// Params are parameters for UDPMux.
|
||||||
type UDPMuxParams struct {
|
type Params struct {
|
||||||
Logger logging.LeveledLogger
|
Logger logging.LeveledLogger
|
||||||
UDPConn net.PacketConn
|
UDPConn net.PacketConn
|
||||||
|
|
||||||
@@ -147,17 +150,18 @@ func isZeros(ip net.IP) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
// NewSingleSocketUDPMux creates an implementation of UDPMux
|
||||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
|
||||||
if params.Logger == nil {
|
if params.Logger == nil {
|
||||||
params.Logger = getLogger()
|
params.Logger = getLogger()
|
||||||
}
|
}
|
||||||
|
|
||||||
mux := &UDPMuxDefault{
|
mux := &SingleSocketUDPMux{
|
||||||
addressMap: map[string][]*udpMuxedConn{},
|
addressMap: map[string][]*udpMuxedConn{},
|
||||||
params: params,
|
params: params,
|
||||||
connsIPv4: make(map[string]*udpMuxedConn),
|
connsIPv4: make(map[string]*udpMuxedConn),
|
||||||
connsIPv6: make(map[string]*udpMuxedConn),
|
connsIPv6: make(map[string]*udpMuxedConn),
|
||||||
|
candidateConnMap: make(map[string]*udpMuxedConn),
|
||||||
closedChan: make(chan struct{}, 1),
|
closedChan: make(chan struct{}, 1),
|
||||||
pool: &sync.Pool{
|
pool: &sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() interface{} {
|
||||||
@@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UDPMuxDefault) updateLocalAddresses() {
|
func (m *SingleSocketUDPMux) updateLocalAddresses() {
|
||||||
var localAddrsForUnspecified []net.Addr
|
var localAddrsForUnspecified []net.Addr
|
||||||
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||||
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
||||||
} else if ok && addr.IP.IsUnspecified() {
|
} else if ok && addr.IP.IsUnspecified() {
|
||||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||||
// it will break the applications that are already using unspecified UDP connection
|
// it will break the applications that are already using unspecified UDP connection
|
||||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
|
||||||
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||||
var networks []ice.NetworkType
|
var networks []ice.NetworkType
|
||||||
switch {
|
switch {
|
||||||
|
|
||||||
@@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
|
|||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
// LocalAddr returns the listening address of this SingleSocketUDPMux
|
||||||
func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
|
||||||
return m.params.UDPConn.LocalAddr()
|
return m.params.UDPConn.LocalAddr()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
|
||||||
m.updateLocalAddresses()
|
m.updateLocalAddresses()
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
|||||||
|
|
||||||
// GetConn returns a PacketConn given the connection's ufrag and network address
|
// GetConn returns a PacketConn given the connection's ufrag and network address
|
||||||
// creates the connection if an existing one can't be found
|
// creates the connection if an existing one can't be found
|
||||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
|
||||||
// don't check addr for mux using unspecified address
|
// don't check addr for mux using unspecified address
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
||||||
@@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
|
|||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c := m.createMuxedConn(ufrag)
|
c := m.createMuxedConn(ufrag, candidateID)
|
||||||
go func() {
|
go func() {
|
||||||
<-c.CloseChannel()
|
<-c.CloseChannel()
|
||||||
m.RemoveConnByUfrag(ufrag)
|
m.RemoveConnByUfrag(ufrag)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
m.candidateConnMap[candidateID] = c
|
||||||
|
|
||||||
if isIPv6 {
|
if isIPv6 {
|
||||||
m.connsIPv6[ufrag] = c
|
m.connsIPv6[ufrag] = c
|
||||||
} else {
|
} else {
|
||||||
@@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveConnByUfrag stops and removes the muxed packet connection
|
// RemoveConnByUfrag stops and removes the muxed packet connection
|
||||||
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
|
||||||
removedConns := make([]*udpMuxedConn, 0, 2)
|
removedConns := make([]*udpMuxedConn, 0, 2)
|
||||||
|
|
||||||
// Keep lock section small to avoid deadlock with conn lock
|
// Keep lock section small to avoid deadlock with conn lock
|
||||||
@@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
|||||||
if c, ok := m.connsIPv4[ufrag]; ok {
|
if c, ok := m.connsIPv4[ufrag]; ok {
|
||||||
delete(m.connsIPv4, ufrag)
|
delete(m.connsIPv4, ufrag)
|
||||||
removedConns = append(removedConns, c)
|
removedConns = append(removedConns, c)
|
||||||
|
delete(m.candidateConnMap, c.GetCandidateID())
|
||||||
}
|
}
|
||||||
if c, ok := m.connsIPv6[ufrag]; ok {
|
if c, ok := m.connsIPv6[ufrag]; ok {
|
||||||
delete(m.connsIPv6, ufrag)
|
delete(m.connsIPv6, ufrag)
|
||||||
removedConns = append(removedConns, c)
|
removedConns = append(removedConns, c)
|
||||||
|
delete(m.candidateConnMap, c.GetCandidateID())
|
||||||
}
|
}
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
@@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsClosed returns true if the mux had been closed
|
// IsClosed returns true if the mux had been closed
|
||||||
func (m *UDPMuxDefault) IsClosed() bool {
|
func (m *SingleSocketUDPMux) IsClosed() bool {
|
||||||
select {
|
select {
|
||||||
case <-m.closedChan:
|
case <-m.closedChan:
|
||||||
return true
|
return true
|
||||||
@@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close the mux, no further connections could be created
|
// Close the mux, no further connections could be created
|
||||||
func (m *UDPMuxDefault) Close() error {
|
func (m *SingleSocketUDPMux) Close() error {
|
||||||
var err error
|
var err error
|
||||||
m.closeOnce.Do(func() {
|
m.closeOnce.Do(func() {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
||||||
return m.params.UDPConn.WriteTo(buf, rAddr)
|
return m.params.UDPConn.WriteTo(buf, rAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
||||||
if m.IsClosed() {
|
if m.IsClosed() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
|||||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
|
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
|
||||||
c := newUDPMuxedConn(&udpMuxedConnParams{
|
c := newUDPMuxedConn(&udpMuxedConnParams{
|
||||||
Mux: m,
|
Mux: m,
|
||||||
Key: key,
|
Key: key,
|
||||||
AddrPool: m.pool,
|
AddrPool: m.pool,
|
||||||
LocalAddr: m.LocalAddr(),
|
LocalAddr: m.LocalAddr(),
|
||||||
Logger: m.params.Logger,
|
Logger: m.params.Logger,
|
||||||
|
CandidateID: candidateID,
|
||||||
})
|
})
|
||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
|
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
|
||||||
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
||||||
|
|
||||||
remoteAddr, ok := addr.(*net.UDPAddr)
|
remoteAddr, ok := addr.(*net.UDPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
|
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have already seen this address dispatch to the appropriate destination
|
// Try to route to specific candidate connection first
|
||||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
if conn := m.findCandidateConnection(msg); conn != nil {
|
||||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
return conn.writePacket(msg.Raw, remoteAddr)
|
||||||
// We will then forward STUN packets to each of these connections.
|
}
|
||||||
m.addressMapMu.RLock()
|
|
||||||
|
// Fallback: route to all possible connections
|
||||||
|
return m.forwardToAllConnections(msg, addr, remoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findCandidateConnection attempts to find the specific connection for a STUN message
|
||||||
|
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
|
||||||
|
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
} else if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
|
||||||
|
if !exists {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// forwardToAllConnections forwards STUN message to all relevant connections
|
||||||
|
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
|
||||||
var destinationConnList []*udpMuxedConn
|
var destinationConnList []*udpMuxedConn
|
||||||
|
|
||||||
|
// Add connections from address map
|
||||||
|
m.addressMapMu.RLock()
|
||||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||||
destinationConnList = append(destinationConnList, storedConns...)
|
destinationConnList = append(destinationConnList, storedConns...)
|
||||||
}
|
}
|
||||||
m.addressMapMu.RUnlock()
|
m.addressMapMu.RUnlock()
|
||||||
|
|
||||||
var isIPv6 bool
|
if conn, ok := m.findConnectionByUsername(msg, addr); ok {
|
||||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
// If we have already seen this address dispatch to the appropriate destination
|
||||||
isIPv6 = true
|
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||||
|
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||||
|
// We will then forward STUN packets to each of these connections.
|
||||||
|
if !m.connectionExists(conn, destinationConnList) {
|
||||||
|
destinationConnList = append(destinationConnList, conn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
|
// Forward to all found connections
|
||||||
// However, we can take a username attribute from the STUN message which contains ufrag.
|
|
||||||
// We can use ufrag to identify the destination conn to route packet to.
|
|
||||||
attr, stunAttrErr := msg.Get(stun.AttrUsername)
|
|
||||||
if stunAttrErr == nil {
|
|
||||||
ufrag := strings.Split(string(attr), ":")[0]
|
|
||||||
|
|
||||||
m.mu.Lock()
|
|
||||||
destinationConn := m.connsIPv4[ufrag]
|
|
||||||
if isIPv6 {
|
|
||||||
destinationConn = m.connsIPv6[ufrag]
|
|
||||||
}
|
|
||||||
|
|
||||||
if destinationConn != nil {
|
|
||||||
exists := false
|
|
||||||
for _, conn := range destinationConnList {
|
|
||||||
if conn.params.Key == destinationConn.params.Key {
|
|
||||||
exists = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
destinationConnList = append(destinationConnList, destinationConn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
|
|
||||||
// It will be discarded by the further ICE candidate logic if so.
|
|
||||||
for _, conn := range destinationConnList {
|
for _, conn := range destinationConnList {
|
||||||
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
|
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
|
||||||
log.Errorf("could not write packet: %v", err)
|
log.Errorf("could not write packet: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
// findConnectionByUsername finds connection using username attribute from STUN message
|
||||||
|
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
|
||||||
|
attr, err := msg.Get(stun.AttrUsername)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
ufrag := strings.Split(string(attr), ":")[0]
|
||||||
|
isIPv6 := isIPv6Address(addr)
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
return m.getConn(ufrag, isIPv6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectionExists checks if a connection already exists in the list
|
||||||
|
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
|
||||||
|
for _, conn := range conns {
|
||||||
|
if conn.params.Key == target.params.Key {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
||||||
if isIPv6 {
|
if isIPv6 {
|
||||||
val, ok = m.connsIPv6[ufrag]
|
val, ok = m.connsIPv6[ufrag]
|
||||||
} else {
|
} else {
|
||||||
@@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isIPv6Address(addr net.Addr) bool {
|
||||||
|
if udpAddr, ok := addr.(*net.UDPAddr); ok {
|
||||||
|
return udpAddr.IP.To4() == nil
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type bufferHolder struct {
|
type bufferHolder struct {
|
||||||
buf []byte
|
buf []byte
|
||||||
}
|
}
|
||||||
@@ -1,12 +1,12 @@
|
|||||||
//go:build !ios
|
//go:build !ios
|
||||||
|
|
||||||
package bind
|
package udpmux
|
||||||
|
|
||||||
import (
|
import (
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||||
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
|
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
|
||||||
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
|
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
|
||||||
conn.RemoveAddress(addr)
|
conn.RemoveAddress(addr)
|
||||||
7
client/iface/udpmux/mux_ios.go
Normal file
7
client/iface/udpmux/mux_ios.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package udpmux
|
||||||
|
|
||||||
|
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
|
||||||
|
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package bind
|
package udpmux
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
|
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
|
||||||
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
|
|||||||
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
||||||
// It then passes packets to the UDPMux that does the actual connection muxing.
|
// It then passes packets to the UDPMux that does the actual connection muxing.
|
||||||
type UniversalUDPMuxDefault struct {
|
type UniversalUDPMuxDefault struct {
|
||||||
*UDPMuxDefault
|
*SingleSocketUDPMux
|
||||||
params UniversalUDPMuxParams
|
params UniversalUDPMuxParams
|
||||||
|
|
||||||
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
|
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
|
||||||
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
address: params.WGAddress,
|
address: params.WGAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
udpMuxParams := UDPMuxParams{
|
udpMuxParams := Params{
|
||||||
Logger: params.Logger,
|
Logger: params.Logger,
|
||||||
UDPConn: m.params.UDPConn,
|
UDPConn: m.params.UDPConn,
|
||||||
Net: m.params.Net,
|
Net: m.params.Net,
|
||||||
}
|
}
|
||||||
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
|
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
|
||||||
|
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
|
|||||||
|
|
||||||
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
||||||
// and return a unique connection per server.
|
// and return a unique connection per server.
|
||||||
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
|
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
|
||||||
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
|
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
|
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
|
||||||
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
|
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
|
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
|
||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -280,15 +280,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
if runningChan != nil {
|
if runningChan != nil {
|
||||||
select {
|
close(runningChan)
|
||||||
case runningChan <- struct{}{}:
|
runningChan = nil
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServiceViaMemory struct {
|
type ServiceViaMemory struct {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type upstreamResolver struct {
|
type upstreamResolver struct {
|
||||||
|
|||||||
@@ -29,9 +29,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
@@ -166,7 +166,7 @@ type Engine struct {
|
|||||||
|
|
||||||
wgInterface WGIface
|
wgInterface WGIface
|
||||||
|
|
||||||
udpMux *bind.UniversalUDPMuxDefault
|
udpMux *udpmux.UniversalUDPMuxDefault
|
||||||
|
|
||||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||||
networkSerial uint64
|
networkSerial uint64
|
||||||
@@ -198,6 +198,10 @@ type Engine struct {
|
|||||||
latestSyncResponse *mgmProto.SyncResponse
|
latestSyncResponse *mgmProto.SyncResponse
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
|
// WireGuard interface monitor
|
||||||
|
wgIfaceMonitor *WGIfaceMonitor
|
||||||
|
wgIfaceMonitorWg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -341,6 +345,9 @@ func (e *Engine) Stop() error {
|
|||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop WireGuard interface monitor and wait for it to exit
|
||||||
|
e.wgIfaceMonitorWg.Wait()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -446,6 +453,8 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return fmt.Errorf("up wg interface: %w", err)
|
return fmt.Errorf("up wg interface: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// if inbound conns are blocked there is no need to create the ACL manager
|
// if inbound conns are blocked there is no need to create the ACL manager
|
||||||
if e.firewall != nil && !e.config.BlockInbound {
|
if e.firewall != nil && !e.config.BlockInbound {
|
||||||
e.acl = acl.NewDefaultManager(e.firewall)
|
e.acl = acl.NewDefaultManager(e.firewall)
|
||||||
@@ -461,7 +470,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
StunTurn: &e.stunTurn,
|
StunTurn: &e.stunTurn,
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
UDPMux: e.udpMux.UDPMuxDefault,
|
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||||
UDPMuxSrflx: e.udpMux,
|
UDPMuxSrflx: e.udpMux,
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
}
|
}
|
||||||
@@ -477,6 +486,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
// starting network monitor at the very last to avoid disruptions
|
// starting network monitor at the very last to avoid disruptions
|
||||||
e.startNetworkMonitor()
|
e.startNetworkMonitor()
|
||||||
|
|
||||||
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
|
e.wgIfaceMonitorWg.Add(1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer e.wgIfaceMonitorWg.Done()
|
||||||
|
|
||||||
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||||
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
|
e.restartEngine()
|
||||||
|
} else if err != nil {
|
||||||
|
log.Warnf("WireGuard interface monitor: %s", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -949,7 +974,6 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.LazyConnectionEnabled,
|
e.config.LazyConnectionEnabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
// err = e.mgmClient.Sync(info, e.handleSync)
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
@@ -960,7 +984,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
}
|
}
|
||||||
log.Debugf("stopped receiving updates from Management Service")
|
log.Debugf("stopped receiving updates from Management Service")
|
||||||
}()
|
}()
|
||||||
log.Debugf("connecting to Management Service updates stream")
|
log.Infof("connecting to Management Service updates stream")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
|
func (e *Engine) updateSTUNs(stuns []*mgmProto.HostConfig) error {
|
||||||
@@ -1327,7 +1351,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
StunTurn: &e.stunTurn,
|
StunTurn: &e.stunTurn,
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
|
||||||
UDPMux: e.udpMux.UDPMuxDefault,
|
UDPMux: e.udpMux.SingleSocketUDPMux,
|
||||||
UDPMuxSrflx: e.udpMux,
|
UDPMuxSrflx: e.udpMux,
|
||||||
NATExternalIPs: e.parseNATExternalIPMappings(),
|
NATExternalIPs: e.parseNATExternalIPMappings(),
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -19,21 +19,18 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
@@ -45,9 +42,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/peers"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@@ -85,7 +85,7 @@ type MockWGIface struct {
|
|||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() wgaddr.Address
|
AddressFunc func() wgaddr.Address
|
||||||
ToInterfaceFunc func() *net.Interface
|
ToInterfaceFunc func() *net.Interface
|
||||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddrFunc func(newAddr string) error
|
UpdateAddrFunc func(newAddr string) error
|
||||||
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemovePeerFunc func(peerKey string) error
|
RemovePeerFunc func(peerKey string) error
|
||||||
@@ -135,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
|
|||||||
return m.ToInterfaceFunc()
|
return m.ToInterfaceFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
|
func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||||
return m.UpFunc()
|
return m.UpFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -414,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
|
engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
|
||||||
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
|
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)
|
||||||
@@ -1555,7 +1555,11 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
|
||||||
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
peersManager := peers.NewManager(store, permissionsManager)
|
||||||
|
|
||||||
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
|
||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -1572,7 +1576,6 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
Return(&types.ExtraSettings{}, nil).
|
Return(&types.ExtraSettings{}, nil).
|
||||||
AnyTimes()
|
AnyTimes()
|
||||||
|
|
||||||
permissionsManager := permissions.NewManager(store)
|
|
||||||
groupsManager := groups.NewManagerMock()
|
groupsManager := groups.NewManagerMock()
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
"github.com/netbirdio/netbird/monotime"
|
"github.com/netbirdio/netbird/monotime"
|
||||||
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
|
|||||||
Name() string
|
Name() string
|
||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
ToInterface() *net.Interface
|
ToInterface() *net.Interface
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*udpmux.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/ti-mo/netfilter"
|
"github.com/ti-mo/netfilter"
|
||||||
|
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultChannelSize = 100
|
const defaultChannelSize = 100
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -174,7 +173,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
|||||||
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
|
||||||
|
|
||||||
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
|
||||||
if os.Getenv("NB_FORCE_RELAY") != "true" {
|
if !isForceRelayed() {
|
||||||
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
14
client/internal/peer/env.go
Normal file
14
client/internal/peer/env.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||||
|
)
|
||||||
|
|
||||||
|
func isForceRelayed() bool {
|
||||||
|
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||||
|
}
|
||||||
@@ -43,13 +43,6 @@ type OfferAnswer struct {
|
|||||||
SessionID *ICESessionID
|
SessionID *ICESessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oa *OfferAnswer) SessionIDString() string {
|
|
||||||
if oa.SessionID == nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
return oa.SessionID.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Handshaker struct {
|
type Handshaker struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
@@ -57,7 +50,7 @@ type Handshaker struct {
|
|||||||
signaler *Signaler
|
signaler *Signaler
|
||||||
ice *WorkerICE
|
ice *WorkerICE
|
||||||
relay *WorkerRelay
|
relay *WorkerRelay
|
||||||
onNewOfferListeners []func(*OfferAnswer)
|
onNewOfferListeners []*OfferListener
|
||||||
|
|
||||||
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
|
||||||
remoteOffersCh chan OfferAnswer
|
remoteOffersCh chan OfferAnswer
|
||||||
@@ -78,7 +71,8 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
|
||||||
h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
|
l := NewOfferListener(offer)
|
||||||
|
h.onNewOfferListeners = append(h.onNewOfferListeners, l)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshaker) Listen(ctx context.Context) {
|
func (h *Handshaker) Listen(ctx context.Context) {
|
||||||
@@ -91,13 +85,13 @@ func (h *Handshaker) Listen(ctx context.Context) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for _, listener := range h.onNewOfferListeners {
|
for _, listener := range h.onNewOfferListeners {
|
||||||
listener(&remoteOfferAnswer)
|
listener.Notify(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
case remoteOfferAnswer := <-h.remoteAnswerCh:
|
||||||
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
|
||||||
for _, listener := range h.onNewOfferListeners {
|
for _, listener := range h.onNewOfferListeners {
|
||||||
listener(&remoteOfferAnswer)
|
listener.Notify(&remoteOfferAnswer)
|
||||||
}
|
}
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
h.log.Infof("stop listening for remote offers and answers")
|
h.log.Infof("stop listening for remote offers and answers")
|
||||||
|
|||||||
62
client/internal/peer/handshaker_listener.go
Normal file
62
client/internal/peer/handshaker_listener.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type callbackFunc func(remoteOfferAnswer *OfferAnswer)
|
||||||
|
|
||||||
|
func (oa *OfferAnswer) SessionIDString() string {
|
||||||
|
if oa.SessionID == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
return oa.SessionID.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
type OfferListener struct {
|
||||||
|
fn callbackFunc
|
||||||
|
running bool
|
||||||
|
latest *OfferAnswer
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOfferListener(fn callbackFunc) *OfferListener {
|
||||||
|
return &OfferListener{
|
||||||
|
fn: fn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *OfferListener) Notify(remoteOfferAnswer *OfferAnswer) {
|
||||||
|
o.mu.Lock()
|
||||||
|
defer o.mu.Unlock()
|
||||||
|
|
||||||
|
// Store the latest offer
|
||||||
|
o.latest = remoteOfferAnswer
|
||||||
|
|
||||||
|
// If already running, the running goroutine will pick up this latest value
|
||||||
|
if o.running {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start processing
|
||||||
|
o.running = true
|
||||||
|
|
||||||
|
// Process in a goroutine to avoid blocking the caller
|
||||||
|
go func(remoteOfferAnswer *OfferAnswer) {
|
||||||
|
for {
|
||||||
|
o.fn(remoteOfferAnswer)
|
||||||
|
|
||||||
|
o.mu.Lock()
|
||||||
|
if o.latest == nil {
|
||||||
|
// No more work to do
|
||||||
|
o.running = false
|
||||||
|
o.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
remoteOfferAnswer = o.latest
|
||||||
|
// Clear the latest to mark it as being processed
|
||||||
|
o.latest = nil
|
||||||
|
o.mu.Unlock()
|
||||||
|
}
|
||||||
|
}(remoteOfferAnswer)
|
||||||
|
}
|
||||||
39
client/internal/peer/handshaker_listener_test.go
Normal file
39
client/internal/peer/handshaker_listener_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_newOfferListener(t *testing.T) {
|
||||||
|
dummyOfferAnswer := &OfferAnswer{}
|
||||||
|
runChan := make(chan struct{}, 10)
|
||||||
|
|
||||||
|
longRunningFn := func(remoteOfferAnswer *OfferAnswer) {
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
runChan <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
hl := NewOfferListener(longRunningFn)
|
||||||
|
|
||||||
|
hl.Notify(dummyOfferAnswer)
|
||||||
|
hl.Notify(dummyOfferAnswer)
|
||||||
|
hl.Notify(dummyOfferAnswer)
|
||||||
|
|
||||||
|
// Wait for exactly 2 callbacks
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
select {
|
||||||
|
case <-runChan:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("Timeout waiting for callback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no additional callbacks happen
|
||||||
|
select {
|
||||||
|
case <-runChan:
|
||||||
|
t.Fatal("Unexpected additional callback")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
t.Log("Correctly received exactly 2 callbacks")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -33,6 +33,7 @@ type WGWatcher struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
ctxLock sync.Mutex
|
ctxLock sync.Mutex
|
||||||
|
enabledTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
|
||||||
@@ -48,6 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
|||||||
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
|
||||||
w.log.Debugf("enable WireGuard watcher")
|
w.log.Debugf("enable WireGuard watcher")
|
||||||
w.ctxLock.Lock()
|
w.ctxLock.Lock()
|
||||||
|
w.enabledTime = time.Now()
|
||||||
|
|
||||||
if w.ctx != nil && w.ctx.Err() == nil {
|
if w.ctx != nil && w.ctx.Err() == nil {
|
||||||
w.log.Errorf("WireGuard watcher already enabled")
|
w.log.Errorf("WireGuard watcher already enabled")
|
||||||
@@ -101,6 +103,11 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
|
|||||||
onDisconnectedFn()
|
onDisconnectedFn()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if lastHandshake.IsZero() {
|
||||||
|
elapsed := handshake.Sub(w.enabledTime).Seconds()
|
||||||
|
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||||
|
}
|
||||||
|
|
||||||
lastHandshake = *handshake
|
lastHandshake = *handshake
|
||||||
|
|
||||||
resetTime := time.Until(handshake.Add(checkPeriod))
|
resetTime := time.Until(handshake.Add(checkPeriod))
|
||||||
|
|||||||
@@ -9,11 +9,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pion/ice/v4"
|
"github.com/pion/ice/v4"
|
||||||
"github.com/pion/stun/v2"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
"github.com/netbirdio/netbird/client/internal/peer/conntype"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -55,10 +54,6 @@ type WorkerICE struct {
|
|||||||
sessionID ICESessionID
|
sessionID ICESessionID
|
||||||
muxAgent sync.Mutex
|
muxAgent sync.Mutex
|
||||||
|
|
||||||
StunTurn []*stun.URI
|
|
||||||
|
|
||||||
sentExtraSrflx bool
|
|
||||||
|
|
||||||
localUfrag string
|
localUfrag string
|
||||||
localPwd string
|
localPwd string
|
||||||
|
|
||||||
@@ -122,7 +117,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
w.log.Warnf("failed to close ICE agent: %s", err)
|
w.log.Warnf("failed to close ICE agent: %s", err)
|
||||||
}
|
}
|
||||||
w.agent = nil
|
w.agent = nil
|
||||||
// todo consider to switch to Relay connection while establishing a new ICE connection
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var preferredCandidateTypes []ice.CandidateType
|
var preferredCandidateTypes []ice.CandidateType
|
||||||
@@ -140,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
w.muxAgent.Unlock()
|
w.muxAgent.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.sentExtraSrflx = false
|
|
||||||
w.agent = agent
|
w.agent = agent
|
||||||
w.agentDialerCancel = dialerCancel
|
w.agentDialerCancel = dialerCancel
|
||||||
w.agentConnecting = true
|
w.agentConnecting = true
|
||||||
@@ -167,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
|
|||||||
w.log.Errorf("error while handling remote candidate")
|
w.log.Errorf("error while handling remote candidate")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if shouldAddExtraCandidate(candidate) {
|
||||||
|
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
||||||
|
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
||||||
|
extraSrflx, err := extraSrflxCandidate(candidate)
|
||||||
|
if err != nil {
|
||||||
|
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
|
||||||
|
w.log.Errorf("error while handling remote candidate")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
|
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
|
||||||
@@ -210,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
|
if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
|
||||||
|
w.onICESelectedCandidatePair(agent, c1, c2)
|
||||||
|
}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
|
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
|
||||||
if !ok {
|
if !ok {
|
||||||
w.log.Warn("invalid udp mux conversion")
|
w.log.Warn("invalid udp mux conversion")
|
||||||
return
|
return
|
||||||
@@ -355,48 +365,19 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
|
|||||||
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if !w.shouldSendExtraSrflxCandidate(candidate) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
|
|
||||||
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
|
|
||||||
extraSrflx, err := extraSrflxCandidate(candidate)
|
|
||||||
if err != nil {
|
|
||||||
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.sentExtraSrflx = true
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
|
|
||||||
if err != nil {
|
|
||||||
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
|
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
|
||||||
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
|
||||||
w.config.Key)
|
w.config.Key)
|
||||||
|
|
||||||
w.muxAgent.Lock()
|
pairStat, ok := agent.GetSelectedCandidatePairStats()
|
||||||
|
if !ok {
|
||||||
pair, err := w.agent.GetSelectedCandidatePair()
|
w.log.Warnf("failed to get selected candidate pair stats")
|
||||||
if err != nil {
|
|
||||||
w.log.Warnf("failed to get selected candidate pair: %s", err)
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if pair == nil {
|
|
||||||
w.log.Warnf("selected candidate pair is nil, cannot proceed")
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.muxAgent.Unlock()
|
|
||||||
|
|
||||||
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
|
duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
|
||||||
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
|
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
|
||||||
w.log.Debugf("failed to update latency for peer: %s", err)
|
w.log.Debugf("failed to update latency for peer: %s", err)
|
||||||
return
|
return
|
||||||
@@ -410,7 +391,10 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
case ice.ConnectionStateConnected:
|
case ice.ConnectionStateConnected:
|
||||||
w.lastKnownState = ice.ConnectionStateConnected
|
w.lastKnownState = ice.ConnectionStateConnected
|
||||||
return
|
return
|
||||||
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected:
|
case ice.ConnectionStateFailed, ice.ConnectionStateDisconnected, ice.ConnectionStateClosed:
|
||||||
|
// ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to
|
||||||
|
// notify the conn.onICEStateDisconnected changes to update the current used priority
|
||||||
|
|
||||||
if w.lastKnownState == ice.ConnectionStateConnected {
|
if w.lastKnownState == ice.ConnectionStateConnected {
|
||||||
w.lastKnownState = ice.ConnectionStateDisconnected
|
w.lastKnownState = ice.ConnectionStateDisconnected
|
||||||
w.conn.onICEStateDisconnected()
|
w.conn.onICEStateDisconnected()
|
||||||
@@ -422,22 +406,31 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
|
|
||||||
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||||
isControlling := w.config.LocalKey > w.config.Key
|
if isController(w.config) {
|
||||||
if isControlling {
|
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
|
||||||
} else {
|
} else {
|
||||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldAddExtraCandidate(candidate ice.Candidate) bool {
|
||||||
|
if candidate.Type() != ice.CandidateTypeServerReflexive {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if candidate.Port() == candidate.RelatedAddress().Port {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
|
||||||
|
// in newer version we generate locally the extra candidate
|
||||||
|
if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
|
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
|
||||||
relatedAdd := candidate.RelatedAddress()
|
relatedAdd := candidate.RelatedAddress()
|
||||||
ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
|
||||||
@@ -453,6 +446,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, e := range candidate.Extensions() {
|
for _, e := range candidate.Extensions() {
|
||||||
|
// overwrite the original candidate ID with the new one to avoid candidate duplication
|
||||||
|
if e.Key == ice.ExtensionKeyCandidateID {
|
||||||
|
e.Value = candidate.ID()
|
||||||
|
}
|
||||||
if err := ec.AddExtension(e); err != nil {
|
if err := ec.AddExtension(e); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProbeResult holds the info about the result of a relay probe request
|
// ProbeResult holds the info about the result of a relay probe request
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
|||||||
notifier := notifier.NewNotifier()
|
notifier := notifier.NewNotifier()
|
||||||
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" && config.WGInterface != nil {
|
||||||
|
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
|
||||||
|
}
|
||||||
|
|
||||||
dm := &DefaultManager{
|
dm := &DefaultManager{
|
||||||
ctx: mCTX,
|
ctx: mCTX,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
|
||||||
log.Warnf("Failed cleaning up routing: %v", err)
|
log.Warnf("Failed cleaning up routing: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
|
|||||||
|
|
||||||
ips := resolveURLsToIPs(initialAddresses)
|
ips := resolveURLsToIPs(initialAddresses)
|
||||||
|
|
||||||
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||||
return fmt.Errorf("setup routing: %w", err)
|
return fmt.Errorf("setup routing: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
|
||||||
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
|
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", err)
|
log.Errorf("Error cleaning up routing: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Info("Routing cleanup complete")
|
log.Info("Routing cleanup complete")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
nbnet.SetVPNInterfaceName("")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package systemops
|
package systemops
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -22,7 +21,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
)
|
)
|
||||||
|
|
||||||
const localSubnetsCacheTTL = 15 * time.Minute
|
const localSubnetsCacheTTL = 15 * time.Minute
|
||||||
@@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Remove hooks selectively
|
hooks.RemoveWriteHooks()
|
||||||
nbnet.RemoveDialerHooks()
|
hooks.RemoveCloseHooks()
|
||||||
nbnet.RemoveListenerHooks()
|
hooks.RemoveAddressRemoveHooks()
|
||||||
|
|
||||||
if err := r.refCounter.Flush(); err != nil {
|
if err := r.refCounter.Flush(); err != nil {
|
||||||
return fmt.Errorf("flush route manager: %w", err)
|
return fmt.Errorf("flush route manager: %w", err)
|
||||||
@@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||||
prefix, err := util.GetPrefixFromIP(ip)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
|
||||||
return fmt.Errorf("adding route reference: %v", err)
|
return fmt.Errorf("adding route reference: %v", err)
|
||||||
}
|
}
|
||||||
@@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
afterHook := func(connID nbnet.ConnectionID) error {
|
afterHook := func(connID hooks.ConnectionID) error {
|
||||||
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
|
||||||
return fmt.Errorf("remove route reference: %w", err)
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
}
|
}
|
||||||
@@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
for _, ip := range initAddresses {
|
for _, ip := range initAddresses {
|
||||||
if err := beforeHook("init", ip); err != nil {
|
prefix, err := util.GetPrefixFromIP(ip)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := beforeHook("init", prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
hooks.AddWriteHook(beforeHook)
|
||||||
if ctx.Err() != nil {
|
hooks.AddCloseHook(afterHook)
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
|
||||||
for _, ip := range resolvedIPs {
|
|
||||||
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
|
||||||
return beforeHook(connID, ip.IP)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
|
||||||
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||||
return fmt.Errorf("remove route reference: %w", err)
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dialer interface {
|
type dialer interface {
|
||||||
@@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
intf, err := net.InterfaceByName(wgInterface.Name())
|
intf, err := net.InterfaceByName(wgInterface.Name())
|
||||||
@@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
@@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
err := r.SetupRouting(nil, nil)
|
advancedRouting := nbnet.AdvancedRouting()
|
||||||
|
err := r.SetupRouting(nil, nil, advancedRouting)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
|
||||||
})
|
})
|
||||||
|
|
||||||
index, err := net.InterfaceByName(wgInterface.Name())
|
index, err := net.InterfaceByName(wgInterface.Name())
|
||||||
|
|||||||
@@ -12,14 +12,14 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
r.prefixes = make(map[netip.Prefix]struct{})
|
r.prefixes = make(map[netip.Prefix]struct{})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPRule contains IP rule information for debugging
|
// IPRule contains IP rule information for debugging
|
||||||
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
|
|||||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
|
||||||
if !nbnet.AdvancedRouting() {
|
if !advancedRouting {
|
||||||
log.Infof("Using legacy routing setup")
|
log.Infof("Using legacy routing setup")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
|
if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
|
||||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
||||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
if !nbnet.AdvancedRouting() {
|
if !advancedRouting {
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -20,11 +20,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type PacketExpectation struct {
|
type PacketExpectation struct {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
@@ -19,9 +20,16 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const InfiniteLifetime = 0xffffffff
|
func init() {
|
||||||
|
nbnet.GetBestInterfaceFunc = GetBestInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
InfiniteLifetime = 0xffffffff
|
||||||
|
)
|
||||||
|
|
||||||
type RouteUpdateType int
|
type RouteUpdateType int
|
||||||
|
|
||||||
@@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct {
|
|||||||
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
|
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// candidateRoute represents a potential route for selection during route lookup
|
||||||
|
type candidateRoute struct {
|
||||||
|
interfaceIndex uint32
|
||||||
|
prefixLength uint8
|
||||||
|
routeMetric uint32
|
||||||
|
interfaceMetric int
|
||||||
|
}
|
||||||
|
|
||||||
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
|
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
|
||||||
type IP_ADDRESS_PREFIX struct {
|
type IP_ADDRESS_PREFIX struct {
|
||||||
Prefix SOCKADDR_INET
|
Prefix SOCKADDR_INET
|
||||||
@@ -177,11 +193,20 @@ const (
|
|||||||
RouteDeleted
|
RouteDeleted
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
|
if advancedRouting {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Using legacy routing setup with ref counters")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||||
|
if advancedRouting {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
return r.cleanupRefCounter(stateManager)
|
return r.cleanupRefCounter(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) {
|
|||||||
|
|
||||||
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
|
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
|
||||||
if table != nil {
|
if table != nil {
|
||||||
ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
|
_, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
|
||||||
if ret != 0 {
|
|
||||||
log.Warnf("FreeMibTable failed with return code: %d", ret)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute {
|
|||||||
entryPtr := basePtr + uintptr(i)*entrySize
|
entryPtr := basePtr + uintptr(i)*entrySize
|
||||||
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
||||||
|
|
||||||
detailed := buildWindowsDetailedRoute(entry)
|
if detailed := buildWindowsDetailedRoute(entry); detailed != nil {
|
||||||
if detailed != nil {
|
|
||||||
detailedRoutes = append(detailedRoutes, *detailed)
|
detailedRoutes = append(detailedRoutes, *detailed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
|
|||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseCandidatesFromTable extracts all matching candidate routes from the routing table
|
||||||
|
func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute {
|
||||||
|
var candidates []candidateRoute
|
||||||
|
entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{})
|
||||||
|
basePtr := uintptr(unsafe.Pointer(&table.Table[0]))
|
||||||
|
|
||||||
|
for i := uint32(0); i < table.NumEntries; i++ {
|
||||||
|
entryPtr := basePtr + uintptr(i)*entrySize
|
||||||
|
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
|
||||||
|
|
||||||
|
if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil {
|
||||||
|
candidates = append(candidates, *candidate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry
|
||||||
|
// Returns nil if the route doesn't match the destination or should be skipped
|
||||||
|
func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute {
|
||||||
|
if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex))
|
||||||
|
if !destPrefix.IsValid() || !destPrefix.Contains(dest) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family)
|
||||||
|
|
||||||
|
return &candidateRoute{
|
||||||
|
interfaceIndex: entry.InterfaceIndex,
|
||||||
|
prefixLength: entry.DestinationPrefix.PrefixLength,
|
||||||
|
routeMetric: entry.Metric,
|
||||||
|
interfaceMetric: interfaceMetric,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getInterfaceMetric retrieves the interface metric for a given interface and address family
|
// getInterfaceMetric retrieves the interface metric for a given interface and address family
|
||||||
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
||||||
if interfaceIndex == 0 {
|
if interfaceIndex == 0 {
|
||||||
@@ -821,6 +882,76 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int {
|
|||||||
return int(ipInterfaceRow.Metric)
|
return int(ipInterfaceRow.Metric)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric
|
||||||
|
func sortRouteCandidates(candidates []candidateRoute) {
|
||||||
|
sort.Slice(candidates, func(i, j int) bool {
|
||||||
|
if candidates[i].prefixLength != candidates[j].prefixLength {
|
||||||
|
return candidates[i].prefixLength > candidates[j].prefixLength
|
||||||
|
}
|
||||||
|
if candidates[i].routeMetric != candidates[j].routeMetric {
|
||||||
|
return candidates[i].routeMetric < candidates[j].routeMetric
|
||||||
|
}
|
||||||
|
return candidates[i].interfaceMetric < candidates[j].interfaceMetric
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBestInterface finds the best interface for reaching a destination,
|
||||||
|
// excluding the VPN interface to avoid routing loops.
|
||||||
|
//
|
||||||
|
// Route selection priority:
|
||||||
|
// 1. Longest prefix match (most specific route)
|
||||||
|
// 2. Lowest route metric
|
||||||
|
// 3. Lowest interface metric
|
||||||
|
func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
|
||||||
|
var skipInterfaceIndex int
|
||||||
|
if vpnIntf != "" {
|
||||||
|
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
|
||||||
|
skipInterfaceIndex = iface.Index
|
||||||
|
} else {
|
||||||
|
// not critical, if we cannot get ahold of the interface then we won't need to skip it
|
||||||
|
log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
table, err := getWindowsRoutingTable()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get routing table: %w", err)
|
||||||
|
}
|
||||||
|
defer freeWindowsRoutingTable(table)
|
||||||
|
|
||||||
|
candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex)
|
||||||
|
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil, fmt.Errorf("no route to %s", dest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort routes: prefix length -> route metric -> interface metric
|
||||||
|
sortRouteCandidates(candidates)
|
||||||
|
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex))
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface.Flags&net.FlagUp == 0 {
|
||||||
|
log.Debugf("interface %s is down, trying next route", iface.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d",
|
||||||
|
dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric)
|
||||||
|
return iface, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("no usable interface found for %s", dest)
|
||||||
|
}
|
||||||
|
|
||||||
// formatRouteAge formats the route age in seconds to a human-readable string
|
// formatRouteAge formats the route age in seconds to a human-readable string
|
||||||
func formatRouteAge(ageSeconds uint32) string {
|
func formatRouteAge(ageSeconds uint32) string {
|
||||||
if ageSeconds == 0 {
|
if ageSeconds == 0 {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
addr = addr.Unmap()
|
addr = addr.Unmap()
|
||||||
|
prefix := netip.PrefixFrom(addr, addr.BitLen())
|
||||||
var prefixLength int
|
|
||||||
switch {
|
|
||||||
case addr.Is4():
|
|
||||||
prefixLength = 32
|
|
||||||
case addr.Is6():
|
|
||||||
prefixLength = 128
|
|
||||||
default:
|
|
||||||
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
|
||||||
return prefix, nil
|
return prefix, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Dial connects to the address on the named network.
|
// Dial connects to the address on the named network.
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ListenPacket listens for incoming packets on the given network and address.
|
// ListenPacket listens for incoming packets on the given network and address.
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
|
|||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
n.iFaceDiscover = pionDiscover{}
|
n.iFaceDiscover = pionDiscover{}
|
||||||
} else {
|
} else {
|
||||||
newMobileIFaceDiscover(iFaceDiscover)
|
n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover)
|
||||||
}
|
}
|
||||||
return n, n.UpdateInterfaces()
|
return n, n.UpdateInterfaces()
|
||||||
}
|
}
|
||||||
|
|||||||
98
client/internal/wg_iface_monitor.go
Normal file
98
client/internal/wg_iface_monitor.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||||
|
// if the interface is deleted externally while the engine is running.
|
||||||
|
type WGIfaceMonitor struct {
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWGIfaceMonitor creates a new WGIfaceMonitor instance.
|
||||||
|
func NewWGIfaceMonitor() *WGIfaceMonitor {
|
||||||
|
return &WGIfaceMonitor{
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins monitoring the WireGuard interface.
|
||||||
|
// It relies on the provided context cancellation to stop.
|
||||||
|
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||||
|
defer close(m.done)
|
||||||
|
|
||||||
|
// Skip on mobile platforms as they handle interface lifecycle differently
|
||||||
|
if runtime.GOOS == "android" || runtime.GOOS == "ios" {
|
||||||
|
log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS)
|
||||||
|
return false, errors.New("not supported on mobile platforms")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ifaceName == "" {
|
||||||
|
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||||
|
return false, errors.New("empty interface name")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get initial interface index to track the specific interface instance
|
||||||
|
expectedIndex, err := getInterfaceIndex(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName)
|
||||||
|
return false, fmt.Errorf("interface %s not found: %w", ifaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||||
|
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||||
|
case <-ticker.C:
|
||||||
|
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
// Interface was deleted
|
||||||
|
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||||
|
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if interface index changed (interface was recreated)
|
||||||
|
if currentIndex != expectedIndex {
|
||||||
|
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||||
|
ifaceName, expectedIndex, currentIndex)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// getInterfaceIndex returns the index of a network interface by name.
|
||||||
|
// Returns an error if the interface is not found.
|
||||||
|
func getInterfaceIndex(name string) (int, error) {
|
||||||
|
if name == "" {
|
||||||
|
return 0, fmt.Errorf("empty interface name")
|
||||||
|
}
|
||||||
|
ifi, err := net.InterfaceByName(name)
|
||||||
|
if err != nil {
|
||||||
|
// Check if it's specifically a "not found" error
|
||||||
|
if errors.Is(err, &net.OpError{}) {
|
||||||
|
// On some systems, this might be a "not found" error
|
||||||
|
return 0, fmt.Errorf("interface not found: %w", err)
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("failed to lookup interface: %w", err)
|
||||||
|
}
|
||||||
|
if ifi == nil {
|
||||||
|
return 0, fmt.Errorf("interface not found")
|
||||||
|
}
|
||||||
|
return ifi.Index, nil
|
||||||
|
}
|
||||||
49
client/net/conn.go
Normal file
49
client/net/conn.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn wraps a net.Conn to override the Close method
|
||||||
|
type Conn struct {
|
||||||
|
net.Conn
|
||||||
|
ID hooks.ConnectionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||||
|
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
||||||
|
func (c *Conn) Close() error {
|
||||||
|
return closeConn(c.ID, c.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPConn wraps net.TCPConn to override its Close method to include hook functionality.
|
||||||
|
type TCPConn struct {
|
||||||
|
*net.TCPConn
|
||||||
|
ID hooks.ConnectionID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
||||||
|
func (c *TCPConn) Close() error {
|
||||||
|
return closeConn(c.ID, c.TCPConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeConn is a helper function to close connections and execute close hooks.
|
||||||
|
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||||
|
err := conn.Close()
|
||||||
|
|
||||||
|
closeHooks := hooks.GetCloseHooks()
|
||||||
|
for _, hook := range closeHooks {
|
||||||
|
if err := hook(id); err != nil {
|
||||||
|
log.Errorf("Error executing close hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
82
client/net/dial.go
Normal file
82
client/net/dial.go
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialUDP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c := conn.(type) {
|
||||||
|
case *net.UDPConn:
|
||||||
|
// Advanced routing: plain connection
|
||||||
|
return c, nil
|
||||||
|
case *Conn:
|
||||||
|
// Legacy routing: wrapped connection preserves close hooks
|
||||||
|
udpConn, ok := c.Conn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn)
|
||||||
|
}
|
||||||
|
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.DialTCP(network, laddr, raddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := NewDialer()
|
||||||
|
dialer.LocalAddr = laddr
|
||||||
|
|
||||||
|
conn, err := dialer.Dial(network, raddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c := conn.(type) {
|
||||||
|
case *net.TCPConn:
|
||||||
|
// Advanced routing: plain connection
|
||||||
|
return c, nil
|
||||||
|
case *Conn:
|
||||||
|
// Legacy routing: wrapped connection preserves close hooks
|
||||||
|
tcpConn, ok := c.Conn.(*net.TCPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn)
|
||||||
|
}
|
||||||
|
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||||
|
}
|
||||||
@@ -16,6 +16,5 @@ func NewDialer() *Dialer {
|
|||||||
Dialer: &net.Dialer{},
|
Dialer: &net.Dialer{},
|
||||||
}
|
}
|
||||||
dialer.init()
|
dialer.init()
|
||||||
|
|
||||||
return dialer
|
return dialer
|
||||||
}
|
}
|
||||||
87
client/net/dialer_dial.go
Normal file
87
client/net/dialer_dial.go
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
||||||
|
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
log.Debugf("Dialing %s %s", network, address)
|
||||||
|
|
||||||
|
if CustomRoutingDisabled() || AdvancedRouting() {
|
||||||
|
return d.Dialer.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
connID := hooks.GenerateConnID()
|
||||||
|
if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil {
|
||||||
|
log.Errorf("Failed to call dialer hooks: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap the connection in Conn to handle Close with hooks
|
||||||
|
return &Conn{Conn: conn, ID: connID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial wraps the net.Dialer's Dial method to use the custom connection
|
||||||
|
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
||||||
|
return d.DialContext(context.Background(), network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
writeHooks := hooks.GetWriteHooks()
|
||||||
|
if len(writeHooks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("split host and port: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver := customResolver
|
||||||
|
if resolver == nil {
|
||||||
|
resolver = net.DefaultResolver
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, ip := range ips {
|
||||||
|
prefix, err := util.GetPrefixFromIP(ip.IP)
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, hook := range writeHooks {
|
||||||
|
if err := hook(connID, prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
7
client/net/dialer_init_generic.go
Normal file
7
client/net/dialer_init_generic.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !linux && !windows
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
func (d *Dialer) init() {
|
||||||
|
// implemented on Linux, Android, and Windows only
|
||||||
|
}
|
||||||
5
client/net/dialer_init_windows.go
Normal file
5
client/net/dialer_init_windows.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
func (d *Dialer) init() {
|
||||||
|
d.Dialer.Control = applyUnicastIFToSocket
|
||||||
|
}
|
||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
|
||||||
|
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CustomRoutingDisabled returns true if custom routing is disabled.
|
// CustomRoutingDisabled returns true if custom routing is disabled.
|
||||||
24
client/net/env_android.go
Normal file
24
client/net/env_android.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
// Init initializes the network environment for Android
|
||||||
|
func Init() {
|
||||||
|
// No initialization needed on Android
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||||
|
// Always returns true on Android since we cannot handle routes dynamically.
|
||||||
|
func AdvancedRouting() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVPNInterfaceName is a no-op on Android
|
||||||
|
func SetVPNInterfaceName(name string) {
|
||||||
|
// No-op on Android - not needed for Android VPN service
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVPNInterfaceName returns empty string on Android
|
||||||
|
func GetVPNInterfaceName() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
23
client/net/env_generic.go
Normal file
23
client/net/env_generic.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
//go:build !linux && !windows && !android
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
// Init initializes the network environment (no-op on non-Linux/Windows platforms)
|
||||||
|
func Init() {
|
||||||
|
// No-op on non-Linux/Windows platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdvancedRouting returns false on non-Linux/Windows platforms
|
||||||
|
func AdvancedRouting() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVPNInterfaceName is a no-op on non-Windows platforms
|
||||||
|
func SetVPNInterfaceName(name string) {
|
||||||
|
// No-op on non-Windows platforms
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVPNInterfaceName returns empty string on non-Windows platforms
|
||||||
|
func GetVPNInterfaceName() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
// these have the same effect, skip socket env supported for backward compatibility
|
// these have the same effect, skip socket env supported for backward compatibility
|
||||||
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
|
||||||
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var advancedRoutingSupported bool
|
var advancedRoutingSupported bool
|
||||||
@@ -27,6 +26,7 @@ func Init() {
|
|||||||
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
|
||||||
func AdvancedRouting() bool {
|
func AdvancedRouting() bool {
|
||||||
return advancedRoutingSupported
|
return advancedRoutingSupported
|
||||||
}
|
}
|
||||||
@@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CheckFwmarkSupport() bool {
|
func CheckFwmarkSupport() bool {
|
||||||
// temporarily enable advanced routing to check fwmarks are supported
|
// temporarily enable advanced routing to check if fwmarks are supported
|
||||||
old := advancedRoutingSupported
|
old := advancedRoutingSupported
|
||||||
advancedRoutingSupported = true
|
advancedRoutingSupported = true
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool {
|
|||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetVPNInterfaceName is a no-op on Linux
|
||||||
|
func SetVPNInterfaceName(name string) {
|
||||||
|
// No-op on Linux - not needed for fwmark-based routing
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVPNInterfaceName returns empty string on Linux
|
||||||
|
func GetVPNInterfaceName() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
67
client/net/env_windows.go
Normal file
67
client/net/env_windows.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
vpnInterfaceName string
|
||||||
|
vpnInitMutex sync.RWMutex
|
||||||
|
|
||||||
|
advancedRoutingSupported bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func Init() {
|
||||||
|
advancedRoutingSupported = checkAdvancedRoutingSupport()
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkAdvancedRoutingSupport() bool {
|
||||||
|
var err error
|
||||||
|
var legacyRouting bool
|
||||||
|
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||||
|
legacyRouting, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if legacyRouting || netstack.IsEnabled() {
|
||||||
|
log.Info("advanced routing has been requested to be disabled")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("system supports advanced routing")
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
|
||||||
|
func AdvancedRouting() bool {
|
||||||
|
return advancedRoutingSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetVPNInterfaceName returns the stored VPN interface name
|
||||||
|
func GetVPNInterfaceName() string {
|
||||||
|
vpnInitMutex.RLock()
|
||||||
|
defer vpnInitMutex.RUnlock()
|
||||||
|
return vpnInterfaceName
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVPNInterfaceName sets the VPN interface name for lazy initialization
|
||||||
|
func SetVPNInterfaceName(name string) {
|
||||||
|
vpnInitMutex.Lock()
|
||||||
|
defer vpnInitMutex.Unlock()
|
||||||
|
vpnInterfaceName = name
|
||||||
|
|
||||||
|
if name != "" {
|
||||||
|
log.Infof("VPN interface name set to %s for route exclusion", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
93
client/net/hooks/hooks.go
Normal file
93
client/net/hooks/hooks.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package hooks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionID provides a globally unique identifier for network connections.
|
||||||
|
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
||||||
|
type ConnectionID string
|
||||||
|
|
||||||
|
// GenerateConnID generates a unique identifier for each connection.
|
||||||
|
func GenerateConnID() ConnectionID {
|
||||||
|
return ConnectionID(uuid.NewString())
|
||||||
|
}
|
||||||
|
|
||||||
|
type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||||
|
type CloseHookFunc func(connID ConnectionID) error
|
||||||
|
type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||||
|
|
||||||
|
var (
|
||||||
|
hooksMutex sync.RWMutex
|
||||||
|
|
||||||
|
writeHooks []WriteHookFunc
|
||||||
|
closeHooks []CloseHookFunc
|
||||||
|
addressRemoveHooks []AddressRemoveHookFunc
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddWriteHook allows adding a new hook to be executed before writing/dialing.
|
||||||
|
func AddWriteHook(hook WriteHookFunc) {
|
||||||
|
hooksMutex.Lock()
|
||||||
|
defer hooksMutex.Unlock()
|
||||||
|
writeHooks = append(writeHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCloseHook allows adding a new hook to be executed on connection close.
|
||||||
|
func AddCloseHook(hook CloseHookFunc) {
|
||||||
|
hooksMutex.Lock()
|
||||||
|
defer hooksMutex.Unlock()
|
||||||
|
closeHooks = append(closeHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveWriteHooks removes all write hooks.
|
||||||
|
func RemoveWriteHooks() {
|
||||||
|
hooksMutex.Lock()
|
||||||
|
defer hooksMutex.Unlock()
|
||||||
|
writeHooks = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveCloseHooks removes all close hooks.
|
||||||
|
func RemoveCloseHooks() {
|
||||||
|
hooksMutex.Lock()
|
||||||
|
defer hooksMutex.Unlock()
|
||||||
|
closeHooks = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed.
|
||||||
|
func AddAddressRemoveHook(hook AddressRemoveHookFunc) {
|
||||||
|
hooksMutex.Lock()
|
||||||
|
defer hooksMutex.Unlock()
|
||||||
|
addressRemoveHooks = append(addressRemoveHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAddressRemoveHooks removes all listener address hooks.
|
||||||
|
func RemoveAddressRemoveHooks() {
|
||||||
|
hooksMutex.Lock()
|
||||||
|
defer hooksMutex.Unlock()
|
||||||
|
addressRemoveHooks = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWriteHooks returns a copy of the current write hooks.
|
||||||
|
func GetWriteHooks() []WriteHookFunc {
|
||||||
|
hooksMutex.RLock()
|
||||||
|
defer hooksMutex.RUnlock()
|
||||||
|
return slices.Clone(writeHooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCloseHooks returns a copy of the current close hooks.
|
||||||
|
func GetCloseHooks() []CloseHookFunc {
|
||||||
|
hooksMutex.RLock()
|
||||||
|
defer hooksMutex.RUnlock()
|
||||||
|
return slices.Clone(closeHooks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAddressRemoveHooks returns a copy of the current listener address remove hooks.
|
||||||
|
func GetAddressRemoveHooks() []AddressRemoveHookFunc {
|
||||||
|
hooksMutex.RLock()
|
||||||
|
defer hooksMutex.RUnlock()
|
||||||
|
return slices.Clone(addressRemoveHooks)
|
||||||
|
}
|
||||||
47
client/net/listen.go
Normal file
47
client/net/listen.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/transport/v3"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenUDP listens on the network address and returns a transport.UDPConn
|
||||||
|
// which includes support for write and close hooks.
|
||||||
|
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
|
||||||
|
if CustomRoutingDisabled() {
|
||||||
|
return net.ListenUDP(network, laddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen UDP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch c := conn.(type) {
|
||||||
|
case *net.UDPConn:
|
||||||
|
// Advanced routing: plain connection
|
||||||
|
return c, nil
|
||||||
|
case *PacketConn:
|
||||||
|
// Legacy routing: wrapped connection for hooks
|
||||||
|
udpConn, ok := c.PacketConn.(*net.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn)
|
||||||
|
}
|
||||||
|
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close connection: %v", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unexpected connection type: %T", conn)
|
||||||
|
}
|
||||||
@@ -7,14 +7,12 @@ import (
|
|||||||
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
|
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
|
||||||
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
|
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
|
||||||
type ListenerConfig struct {
|
type ListenerConfig struct {
|
||||||
*net.ListenConfig
|
net.ListenConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewListener creates a new ListenerConfig instance.
|
// NewListener creates a new ListenerConfig instance.
|
||||||
func NewListener() *ListenerConfig {
|
func NewListener() *ListenerConfig {
|
||||||
listener := &ListenerConfig{
|
listener := &ListenerConfig{}
|
||||||
ListenConfig: &net.ListenConfig{},
|
|
||||||
}
|
|
||||||
listener.init()
|
listener.init()
|
||||||
|
|
||||||
return listener
|
return listener
|
||||||
7
client/net/listener_init_generic.go
Normal file
7
client/net/listener_init_generic.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !linux && !windows
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
func (l *ListenerConfig) init() {
|
||||||
|
// implemented on Linux, Android, and Windows only
|
||||||
|
}
|
||||||
8
client/net/listener_init_windows.go
Normal file
8
client/net/listener_init_windows.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
func (l *ListenerConfig) init() {
|
||||||
|
// TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses.
|
||||||
|
// For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case
|
||||||
|
// the interface will be selected that serves the default route.
|
||||||
|
l.ListenConfig.Control = applyUnicastIFToSocket
|
||||||
|
}
|
||||||
153
client/net/listener_listen.go
Normal file
153
client/net/listener_listen.go
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
|
"github.com/netbirdio/netbird/client/net/hooks"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ListenPacket listens on the network address and returns a PacketConn
|
||||||
|
// which includes support for write hooks.
|
||||||
|
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
||||||
|
if CustomRoutingDisabled() || AdvancedRouting() {
|
||||||
|
return l.ListenConfig.ListenPacket(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("listen packet: %w", err)
|
||||||
|
}
|
||||||
|
connID := hooks.GenerateConnID()
|
||||||
|
|
||||||
|
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
|
||||||
|
type PacketConn struct {
|
||||||
|
net.PacketConn
|
||||||
|
ID hooks.ConnectionID
|
||||||
|
seenAddrs *sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||||
|
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
|
||||||
|
log.Errorf("Failed to call write hooks: %v", err)
|
||||||
|
}
|
||||||
|
return c.PacketConn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
||||||
|
func (c *PacketConn) Close() error {
|
||||||
|
defer c.seenAddrs.Clear()
|
||||||
|
return closeConn(c.ID, c.PacketConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
|
||||||
|
type UDPConn struct {
|
||||||
|
*net.UDPConn
|
||||||
|
ID hooks.ConnectionID
|
||||||
|
seenAddrs *sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
||||||
|
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||||
|
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
|
||||||
|
log.Errorf("Failed to call write hooks: %v", err)
|
||||||
|
}
|
||||||
|
return c.UDPConn.WriteTo(b, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
||||||
|
func (c *UDPConn) Close() error {
|
||||||
|
defer c.seenAddrs.Clear()
|
||||||
|
return closeConn(c.ID, c.UDPConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
|
||||||
|
func (c *PacketConn) RemoveAddress(addr string) {
|
||||||
|
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ipStr, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error splitting IP address and port: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ipAddr, err := netip.ParseAddr(ipStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen())
|
||||||
|
|
||||||
|
addressRemoveHooks := hooks.GetAddressRemoveHooks()
|
||||||
|
if len(addressRemoveHooks) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, hook := range addressRemoveHooks {
|
||||||
|
if err := hook(c.ID, prefix); err != nil {
|
||||||
|
log.Errorf("Error executing listener address remove hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality
|
||||||
|
func WrapPacketConn(conn net.PacketConn) net.PacketConn {
|
||||||
|
if AdvancedRouting() {
|
||||||
|
// hooks not required for advanced routing
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
return &PacketConn{
|
||||||
|
PacketConn: conn,
|
||||||
|
ID: hooks.GenerateConnID(),
|
||||||
|
seenAddrs: &sync.Map{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error {
|
||||||
|
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
writeHooks := hooks.GetWriteHooks()
|
||||||
|
if len(writeHooks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
udpAddr, ok := addr.(*net.UDPAddr)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix, err := util.GetPrefixFromIP(udpAddr.IP)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Listener resolved IP for %s: %s", addr, prefix)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, hook := range writeHooks {
|
||||||
|
if err := hook(id, prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
@@ -5,8 +5,6 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool {
|
|||||||
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
|
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConnectionID provides a globally unique identifier for network connections.
|
|
||||||
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
|
||||||
type ConnectionID string
|
|
||||||
|
|
||||||
type AddHookFunc func(connID ConnectionID, IP net.IP) error
|
|
||||||
type RemoveHookFunc func(connID ConnectionID) error
|
|
||||||
|
|
||||||
// GenerateConnID generates a unique identifier for each connection.
|
|
||||||
func GenerateConnID() ConnectionID {
|
|
||||||
return ConnectionID(uuid.NewString())
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
|
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
|
||||||
var endIP net.IP
|
var endIP net.IP
|
||||||
addr := network.Addr().AsSlice()
|
addr := network.Addr().AsSlice()
|
||||||
284
client/net/net_windows.go
Normal file
284
client/net/net_windows.go
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
|
||||||
|
IpUnicastIf = 31
|
||||||
|
Ipv6UnicastIf = 31
|
||||||
|
|
||||||
|
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options
|
||||||
|
Ipv6V6only = 27
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetBestInterfaceFunc is set at runtime to avoid import cycle
|
||||||
|
var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error)
|
||||||
|
|
||||||
|
// nativeToBigEndian converts a uint32 from native byte order to big-endian
|
||||||
|
func nativeToBigEndian(v uint32) uint32 {
|
||||||
|
return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseDestinationAddress parses the destination address from various formats
|
||||||
|
func parseDestinationAddress(network, address string) (netip.Addr, error) {
|
||||||
|
if address == "" {
|
||||||
|
if strings.HasSuffix(network, "6") {
|
||||||
|
return netip.IPv6Unspecified(), nil
|
||||||
|
}
|
||||||
|
return netip.IPv4Unspecified(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if addrPort, err := netip.ParseAddrPort(address); err == nil {
|
||||||
|
return addrPort.Addr(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dest, err := netip.ParseAddr(address); err == nil {
|
||||||
|
return dest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
// No port, treat whole string as host
|
||||||
|
host = address
|
||||||
|
}
|
||||||
|
|
||||||
|
if host == "" {
|
||||||
|
if strings.HasSuffix(network, "6") {
|
||||||
|
return netip.IPv6Unspecified(), nil
|
||||||
|
}
|
||||||
|
return netip.IPv4Unspecified(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
|
||||||
|
if err != nil || len(ips) == 0 {
|
||||||
|
return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dest, ok := netip.AddrFromSlice(ips[0].IP)
|
||||||
|
if !ok {
|
||||||
|
return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ips[0].Zone != "" {
|
||||||
|
dest = dest.WithZone(ips[0].Zone)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getInterfaceFromZone(zone string) *net.Interface {
|
||||||
|
if zone == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
idx, err := strconv.Atoi(zone)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("invalid zone format for Windows (expected numeric): %s", zone)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := net.InterfaceByIndex(idx)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to get interface by index %d from zone: %v", idx, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return iface
|
||||||
|
}
|
||||||
|
|
||||||
|
type interfaceSelection struct {
|
||||||
|
iface4 *net.Interface
|
||||||
|
iface6 *net.Interface
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection {
|
||||||
|
iface := getInterfaceFromZone(zone)
|
||||||
|
if iface == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if dest.Is6() {
|
||||||
|
return &interfaceSelection{iface6: iface}
|
||||||
|
}
|
||||||
|
return &interfaceSelection{iface4: iface}
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectInterfaceForUnspecified() (*interfaceSelection, error) {
|
||||||
|
if GetBestInterfaceFunc == nil {
|
||||||
|
return nil, errors.New("GetBestInterfaceFunc not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
var result interfaceSelection
|
||||||
|
vpnIfaceName := GetVPNInterfaceName()
|
||||||
|
|
||||||
|
if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil {
|
||||||
|
result.iface4 = iface4
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv4 default route found: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil {
|
||||||
|
result.iface6 = iface6
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv6 default route found: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.iface4 == nil && result.iface6 == nil {
|
||||||
|
return nil, errors.New("no default routes found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectInterface(dest netip.Addr) (*interfaceSelection, error) {
|
||||||
|
if zone := dest.Zone(); zone != "" {
|
||||||
|
if selection := selectInterfaceForZone(dest, zone); selection != nil {
|
||||||
|
return selection, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if dest.IsUnspecified() {
|
||||||
|
return selectInterfaceForUnspecified()
|
||||||
|
}
|
||||||
|
|
||||||
|
if GetBestInterfaceFunc == nil {
|
||||||
|
return nil, errors.New("GetBestInterfaceFunc not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("find route for %s: %w", dest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dest.Is6() {
|
||||||
|
return &interfaceSelection{iface6: iface}, nil
|
||||||
|
}
|
||||||
|
return &interfaceSelection{iface4: iface}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error {
|
||||||
|
ifaceIndexBE := nativeToBigEndian(uint32(iface.Index))
|
||||||
|
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil {
|
||||||
|
return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error {
|
||||||
|
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil {
|
||||||
|
return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||||
|
// The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.)
|
||||||
|
// Never generic ones (udp, tcp, ip)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasSuffix(network, "4"):
|
||||||
|
// IPv4-only socket (udp4, tcp4, ip4)
|
||||||
|
return setUnicastIfIPv4(fd, network, selection, address)
|
||||||
|
|
||||||
|
case strings.HasSuffix(network, "6"):
|
||||||
|
// IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only
|
||||||
|
return setUnicastIfIPv6(fd, network, selection, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shouldn't reach here based on Go's documented behavior
|
||||||
|
return fmt.Errorf("unexpected network type: %s", network)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||||
|
if selection.iface4 == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error {
|
||||||
|
isDualStack := checkDualStack(fd)
|
||||||
|
|
||||||
|
// For dual-stack sockets, also set the IPv4 option
|
||||||
|
if isDualStack && selection.iface4 != nil {
|
||||||
|
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
if selection.iface6 == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := setIPv6UnicastIF(fd, selection.iface6); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkDualStack(fd uintptr) bool {
|
||||||
|
var v6Only int
|
||||||
|
v6OnlyLen := int32(unsafe.Sizeof(v6Only))
|
||||||
|
err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen)
|
||||||
|
return err == nil && v6Only == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address
|
||||||
|
func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error {
|
||||||
|
if !AdvancedRouting() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dest, err := parseDestinationAddress(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dest = dest.Unmap()
|
||||||
|
|
||||||
|
if !dest.IsValid() {
|
||||||
|
return fmt.Errorf("invalid destination address for %s", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
selection, err := selectInterface(dest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var controlErr error
|
||||||
|
err = c.Control(func(fd uintptr) {
|
||||||
|
controlErr = setUnicastIf(fd, network, selection, address)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("control: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return controlErr
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user