mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Compare commits
17 Commits
fix/androi
...
feature/de
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b2d7121695 | ||
|
|
3288c4414f | ||
|
|
f3b0439211 | ||
|
|
ae801d77fb | ||
|
|
bf83549db2 | ||
|
|
804a3871fe | ||
|
|
64d1edce27 | ||
|
|
bf0698e5aa | ||
|
|
fc15625963 | ||
|
|
a75dde33b9 | ||
|
|
bb46e438aa | ||
|
|
11ba253ffb | ||
|
|
14fe7c29cb | ||
|
|
158f3aceff | ||
|
|
bfa776c155 | ||
|
|
885b5c68ad | ||
|
|
b1ebac795d |
@@ -4,12 +4,15 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
@@ -17,6 +20,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
||||||
|
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
||||||
|
|
||||||
// Backoff returns a backoff configuration for gRPC calls
|
// Backoff returns a backoff configuration for gRPC calls
|
||||||
func Backoff(ctx context.Context) backoff.BackOff {
|
func Backoff(ctx context.Context) backoff.BackOff {
|
||||||
b := backoff.NewExponentialBackOff()
|
b := backoff.NewExponentialBackOff()
|
||||||
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
return backoff.WithContext(b, ctx)
|
return backoff.WithContext(b, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
||||||
|
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
||||||
|
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||||
|
conn.Connect()
|
||||||
|
|
||||||
|
state := conn.GetState()
|
||||||
|
for state != connectivity.Ready && state != connectivity.Shutdown {
|
||||||
|
if !conn.WaitForStateChange(ctx, state) {
|
||||||
|
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
||||||
|
}
|
||||||
|
state = conn.GetState()
|
||||||
|
}
|
||||||
|
|
||||||
|
if state == connectivity.Shutdown {
|
||||||
|
return ErrConnectionShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||||
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
conn, err := grpc.NewClient(
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
|
||||||
connCtx,
|
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("DialContext error: %v", err)
|
return nil, fmt.Errorf("new client: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := waitForConnectionReady(ctx, conn); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
if runtime.GOOS == "linux" {
|
if runtime.GOOS == "linux" {
|
||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
|||||||
|
|
||||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to dial: %s", err)
|
|
||||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||||
}
|
}
|
||||||
return conn, nil
|
return conn, nil
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -34,7 +35,6 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
if c.engine != nil && c.engine.wgInterface != nil {
|
engine := c.engine
|
||||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
|
c.engine = nil
|
||||||
if err := c.engine.Stop(); err != nil {
|
c.engineMutex.Unlock()
|
||||||
|
|
||||||
|
if engine != nil && engine.wgInterface != nil {
|
||||||
|
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||||
|
if err := engine.Stop(); err != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
log.Errorf("Failed to stop engine: %v", err)
|
||||||
}
|
}
|
||||||
c.engine = nil
|
|
||||||
}
|
}
|
||||||
c.engineMutex.Unlock()
|
|
||||||
c.statusRecorder.ClientTeardown()
|
c.statusRecorder.ClientTeardown()
|
||||||
|
|
||||||
backOff.Reset()
|
backOff.Reset()
|
||||||
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) Stop() error {
|
func (c *ConnectClient) Stop() error {
|
||||||
if c == nil {
|
engine := c.Engine()
|
||||||
return nil
|
if engine != nil {
|
||||||
|
if err := engine.Stop(); err != nil {
|
||||||
|
return fmt.Errorf("stop engine: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.engineMutex.Lock()
|
|
||||||
defer c.engineMutex.Unlock()
|
|
||||||
|
|
||||||
if c.engine == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err := c.engine.Stop(); err != nil {
|
|
||||||
return fmt.Errorf("stop engine: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
|
|||||||
|
|
||||||
// DefaultServer dns server object
|
// DefaultServer dns server object
|
||||||
type DefaultServer struct {
|
type DefaultServer struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
|
shutdownWg sync.WaitGroup
|
||||||
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
|
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
|
||||||
// This is different from ServiceEnable=false from management which completely disables the DNS service.
|
// This is different from ServiceEnable=false from management which completely disables the DNS service.
|
||||||
disableSys bool
|
disableSys bool
|
||||||
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
|||||||
// Stop stops the server
|
// Stop stops the server
|
||||||
func (s *DefaultServer) Stop() {
|
func (s *DefaultServer) Stop() {
|
||||||
s.ctxCancel()
|
s.ctxCancel()
|
||||||
|
s.shutdownWg.Wait()
|
||||||
|
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
|
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
|
|
||||||
|
s.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
// persist dns state right away
|
defer s.shutdownWg.Done()
|
||||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
log.Errorf("Failed to persist dns state: %v", err)
|
log.Errorf("Failed to persist dns state: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -200,8 +200,10 @@ type Engine struct {
|
|||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
// WireGuard interface monitor
|
// WireGuard interface monitor
|
||||||
wgIfaceMonitor *WGIfaceMonitor
|
wgIfaceMonitor *WGIfaceMonitor
|
||||||
wgIfaceMonitorWg sync.WaitGroup
|
|
||||||
|
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||||
|
shutdownWg sync.WaitGroup
|
||||||
|
|
||||||
// dns forwarder port
|
// dns forwarder port
|
||||||
dnsFwdPort uint16
|
dnsFwdPort uint16
|
||||||
@@ -326,10 +328,6 @@ func (e *Engine) Stop() error {
|
|||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
|
||||||
// Removing peers happens in the conn.Close() asynchronously
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
|
|
||||||
// stop flow manager after wg interface is gone
|
// stop flow manager after wg interface is gone
|
||||||
@@ -337,8 +335,6 @@ func (e *Engine) Stop() error {
|
|||||||
e.flowManager.Close()
|
e.flowManager.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -349,12 +345,52 @@ func (e *Engine) Stop() error {
|
|||||||
log.Errorf("failed to persist state: %v", err)
|
log.Errorf("failed to persist state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop WireGuard interface monitor and wait for it to exit
|
timeout := e.calculateShutdownTimeout()
|
||||||
e.wgIfaceMonitorWg.Wait()
|
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||||
|
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||||
|
func (e *Engine) calculateShutdownTimeout() time.Duration {
|
||||||
|
peerCount := len(e.peerStore.PeersPubKey())
|
||||||
|
|
||||||
|
baseTimeout := 10 * time.Second
|
||||||
|
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
|
||||||
|
timeout := baseTimeout + perPeerTimeout
|
||||||
|
|
||||||
|
maxTimeout := 30 * time.Second
|
||||||
|
if timeout > maxTimeout {
|
||||||
|
timeout = maxTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
return timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitWithContext blocks until wg is released or ctx ends, whichever is
// observed first.
//
// It returns nil when wg completed and ctx.Err() when ctx ended first. If both
// are observable at the same moment, completion wins, so a group that finished
// right at the deadline is not reported as timed out.
//
// The helper goroutine keeps waiting on wg after an early return and exits only
// once wg is eventually released.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		return nil
	case <-ctx.Done():
	}

	// select chooses randomly among ready cases, so completion may already be
	// signalled even though the context branch was taken; prefer it if so.
	select {
	case <-done:
		return nil
	default:
		return ctx.Err()
	}
}
|
||||||
|
|
||||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||||
// Connections to remote peers are not established here.
|
// Connections to remote peers are not established here.
|
||||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||||
@@ -484,14 +520,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
e.wgIfaceMonitorWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer e.wgIfaceMonitorWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
|
|
||||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
e.restartEngine()
|
e.triggerClientRestart()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
log.Warnf("WireGuard interface monitor: %s", err)
|
log.Warnf("WireGuard interface monitor: %s", err)
|
||||||
}
|
}
|
||||||
@@ -892,7 +928,9 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create ssh server: %w", err)
|
return fmt.Errorf("create ssh server: %w", err)
|
||||||
}
|
}
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
// blocking
|
// blocking
|
||||||
err = e.sshServer.Start()
|
err = e.sshServer.Start()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -950,7 +988,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||||
func (e *Engine) receiveManagementEvents() {
|
func (e *Engine) receiveManagementEvents() {
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get system info with checks: %v", err)
|
log.Warnf("failed to get system info with checks: %v", err)
|
||||||
@@ -1368,7 +1408,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
|
|
||||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||||
func (e *Engine) receiveSignalEvents() {
|
func (e *Engine) receiveSignalEvents() {
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
// connect to a stream of messages coming from the signal server
|
// connect to a stream of messages coming from the signal server
|
||||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
@@ -1724,8 +1766,10 @@ func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// restartEngine restarts the engine by cancelling the client context
|
// triggerClientRestart triggers a full client restart by cancelling the client context.
|
||||||
func (e *Engine) restartEngine() {
|
// Note: This does NOT just restart the engine - it cancels the entire client context,
|
||||||
|
// which causes the connect client's retry loop to create a completely new engine.
|
||||||
|
func (e *Engine) triggerClientRestart() {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -1747,7 +1791,9 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
e.networkMonitor = networkmonitor.New()
|
e.networkMonitor = networkmonitor.New()
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
log.Infof("network monitor stopped")
|
log.Infof("network monitor stopped")
|
||||||
@@ -1757,8 +1803,8 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Network monitor: detected network change, restarting engine")
|
log.Infof("Network monitor: detected network change, triggering client restart")
|
||||||
e.restartEngine()
|
e.triggerClientRestart()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
// Manager handles netflow tracking and logging
|
// Manager handles netflow tracking and logging
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
|
shutdownWg sync.WaitGroup
|
||||||
logger nftypes.FlowLogger
|
logger nftypes.FlowLogger
|
||||||
flowConfig *nftypes.FlowConfig
|
flowConfig *nftypes.FlowConfig
|
||||||
conntrack nftypes.ConnTracker
|
conntrack nftypes.ConnTracker
|
||||||
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
m.cancel = cancel
|
m.cancel = cancel
|
||||||
|
|
||||||
go m.receiveACKs(ctx, flowClient)
|
m.shutdownWg.Add(2)
|
||||||
go m.startSender(ctx)
|
go func() {
|
||||||
|
defer m.shutdownWg.Done()
|
||||||
|
m.receiveACKs(ctx, flowClient)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer m.shutdownWg.Done()
|
||||||
|
m.startSender(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
|||||||
// Close cleans up all resources
|
// Close cleans up all resources
|
||||||
func (m *Manager) Close() {
|
func (m *Manager) Close() {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
|
||||||
|
|
||||||
if err := m.disableFlow(); err != nil {
|
if err := m.disableFlow(); err != nil {
|
||||||
log.Warnf("failed to disable flow manager: %v", err)
|
log.Warnf("failed to disable flow manager: %v", err)
|
||||||
}
|
}
|
||||||
|
m.mux.Unlock()
|
||||||
|
|
||||||
|
m.shutdownWg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLogger returns the flow logger
|
// GetLogger returns the flow logger
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||||
|
|
||||||
package networkmonitor
|
package networkmonitor
|
||||||
|
|
||||||
|
|||||||
344
client/internal/networkmonitor/check_change_darwin.go
Normal file
344
client/internal/networkmonitor/check_change_darwin.go
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
//go:build darwin && !ios
|
||||||
|
|
||||||
|
package networkmonitor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/route"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
// todo: refactor to not use static functions
|
||||||
|
|
||||||
|
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||||
|
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open routing socket: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := unix.Close(fd)
|
||||||
|
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||||
|
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
routeChanged := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||||
|
close(routeChanged)
|
||||||
|
}()
|
||||||
|
|
||||||
|
wakeUp := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wakeUpListen(ctx)
|
||||||
|
close(wakeUp)
|
||||||
|
}()
|
||||||
|
|
||||||
|
gatewayChanged := make(chan string)
|
||||||
|
go func() {
|
||||||
|
gatewayPoll(ctx, nexthopv4, nexthopv6, gatewayChanged)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-routeChanged:
|
||||||
|
log.Infof("route change detected via routing socket")
|
||||||
|
return nil
|
||||||
|
case <-wakeUp:
|
||||||
|
log.Infof("wakeup detected via sleep hash change")
|
||||||
|
return nil
|
||||||
|
case reason := <-gatewayChanged:
|
||||||
|
log.Infof("gateway change detected via polling: %s", reason)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func routeCheck(ctx context.Context, fd int, nexthopv4 systemops.Nexthop, nexthopv6 systemops.Nexthop) {
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 2048)
|
||||||
|
n, err := unix.Read(fd, buf)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||||
|
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n < unix.SizeofRtMsghdr {
|
||||||
|
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||||
|
|
||||||
|
switch msg.Type {
|
||||||
|
// handle route changes
|
||||||
|
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||||
|
route, err := parseRouteMessage(buf[:n])
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if route.Dst.Bits() != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
intf := "<nil>"
|
||||||
|
if route.Interface != nil {
|
||||||
|
intf = route.Interface.Name
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case unix.RTM_ADD:
|
||||||
|
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||||
|
return
|
||||||
|
case unix.RTM_DELETE:
|
||||||
|
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||||
|
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) != 1 {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg, ok := msgs[0].(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
return systemops.MsgToRoute(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wakeUpListen(ctx context.Context) {
|
||||||
|
log.Infof("start to watch for system wakeups")
|
||||||
|
var (
|
||||||
|
initialHash uint32
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
// Keep retrying until initial sysctl succeeds or context is canceled
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
initialHash, err = readSleepTimeHash()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to detect initial sleep time: %v", err)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||||
|
return
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Infof("initial wakeup hash: %d", initialHash)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
lastCheck := time.Now()
|
||||||
|
const maxTickerDrift = 1 * time.Minute
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info("context canceled, stopping wakeUpListen")
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
now := time.Now()
|
||||||
|
elapsed := now.Sub(lastCheck)
|
||||||
|
|
||||||
|
// If more time passed than expected, system likely slept (informational only)
|
||||||
|
if elapsed > maxTickerDrift {
|
||||||
|
upOut, err := exec.Command("uptime").Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to run uptime command: %v", err)
|
||||||
|
upOut = []byte("unknown")
|
||||||
|
}
|
||||||
|
log.Infof("Time drift detected (potential wakeup): expected ~5s, actual %s, uptime: %s", elapsed, upOut)
|
||||||
|
|
||||||
|
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
if errV4 == nil {
|
||||||
|
log.Infof("Current IPv4 default gateway: %s via %s", currentV4.IP, currentV4.Intf.Name)
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv4 default gateway: %v", errV4)
|
||||||
|
}
|
||||||
|
if errV6 == nil {
|
||||||
|
log.Infof("Current IPv6 default gateway: %s via %s", currentV6.IP, currentV6.Intf.Name)
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv6 default gateway: %v", errV6)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newHash, err := readSleepTimeHash()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to read sleep time hash: %v", err)
|
||||||
|
lastCheck = now
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if newHash == initialHash {
|
||||||
|
log.Debugf("no wakeup detected (hash unchanged: %d, time drift: %s)", initialHash, elapsed)
|
||||||
|
lastCheck = now
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
upOut, err := exec.Command("uptime").Output()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to run uptime command: %v", err)
|
||||||
|
upOut = []byte("unknown")
|
||||||
|
}
|
||||||
|
log.Infof("Wakeup detected via hash change: %d -> %d, uptime: %s", initialHash, newHash, upOut)
|
||||||
|
|
||||||
|
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
if errV4 == nil {
|
||||||
|
log.Infof("Current IPv4 default gateway after wakeup: %s via %s", currentV4.IP, currentV4.Intf.Name)
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv4 default gateway after wakeup: %v", errV4)
|
||||||
|
}
|
||||||
|
if errV6 == nil {
|
||||||
|
log.Infof("Current IPv6 default gateway after wakeup: %s via %s", currentV6.IP, currentV6.Intf.Name)
|
||||||
|
} else {
|
||||||
|
log.Debugf("No IPv6 default gateway after wakeup: %v", errV6)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func readSleepTimeHash() (uint32, error) {
|
||||||
|
cmd := exec.Command("sysctl", "kern.sleeptime")
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to run sysctl: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := hash(out)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to compute hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hash returns the 32-bit FNV-1a checksum of data. An error reported by the
// hasher's Write is passed through unchanged with a zero checksum.
func hash(data []byte) (uint32, error) {
	digest := fnv.New32a()

	_, err := digest.Write(data)
	if err != nil {
		return 0, err
	}

	sum := digest.Sum32()
	return sum, nil
}
|
||||||
|
|
||||||
|
// gatewayPoll polls the default gateway every 5 seconds to detect changes that might be missed by routing socket or wake-up detection.
|
||||||
|
func gatewayPoll(ctx context.Context, initialV4, initialV6 systemops.Nexthop, changed chan<- string) {
|
||||||
|
ticker := time.NewTicker(5 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
log.Infof("Gateway polling started - initial v4: %s via %v, v6: %s via %v",
|
||||||
|
initialV4.IP, initialV4.Intf, initialV6.IP, initialV6.Intf)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Debug("context canceled, stopping gateway polling")
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-ticker.C:
|
||||||
|
currentV4, errV4 := systemops.GetNextHop(netip.IPv4Unspecified())
|
||||||
|
currentV6, errV6 := systemops.GetNextHop(netip.IPv6Unspecified())
|
||||||
|
|
||||||
|
var reason string
|
||||||
|
|
||||||
|
if errV4 == nil && initialV4.IP.IsValid() {
|
||||||
|
if currentV4.IP.Compare(initialV4.IP) != 0 {
|
||||||
|
reason = fmt.Sprintf("IPv4 gateway changed from %s to %s", initialV4.IP, currentV4.IP)
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if initialV4.Intf != nil && currentV4.Intf != nil && currentV4.Intf.Name != initialV4.Intf.Name {
|
||||||
|
reason = fmt.Sprintf("IPv4 interface changed from %s to %s", initialV4.Intf.Name, currentV4.Intf.Name)
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if errV4 == nil && !initialV4.IP.IsValid() {
|
||||||
|
reason = "IPv4 gateway appeared"
|
||||||
|
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV4.IP)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
} else if errV4 != nil && initialV4.IP.IsValid() {
|
||||||
|
reason = "IPv4 gateway disappeared"
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if errV6 == nil && initialV6.IP.IsValid() {
|
||||||
|
if currentV6.IP.Compare(initialV6.IP) != 0 {
|
||||||
|
reason = fmt.Sprintf("IPv6 gateway changed from %s to %s", initialV6.IP, currentV6.IP)
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if initialV6.Intf != nil && currentV6.Intf != nil && currentV6.Intf.Name != initialV6.Intf.Name {
|
||||||
|
reason = fmt.Sprintf("IPv6 interface changed from %s to %s", initialV6.Intf.Name, currentV6.Intf.Name)
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if errV6 == nil && !initialV6.IP.IsValid() {
|
||||||
|
reason = "IPv6 gateway appeared"
|
||||||
|
log.Infof("Gateway poll detected change: %s (new: %s)", reason, currentV6.IP)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
} else if errV6 != nil && initialV6.IP.IsValid() {
|
||||||
|
reason = "IPv6 gateway disappeared"
|
||||||
|
log.Infof("Gateway poll detected change: %s", reason)
|
||||||
|
changed <- reason
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Gateway poll: no change detected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
|||||||
event := make(chan struct{}, 1)
|
event := make(chan struct{}, 1)
|
||||||
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||||
|
|
||||||
|
log.Infof("start watching for network changes")
|
||||||
// debounce changes
|
// debounce changes
|
||||||
timer := time.NewTimer(0)
|
timer := time.NewTimer(0)
|
||||||
timer.Stop()
|
timer.Stop()
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ type SRWatcher struct {
|
|||||||
signalClient chNotifier
|
signalClient chNotifier
|
||||||
relayManager chNotifier
|
relayManager chNotifier
|
||||||
|
|
||||||
listeners map[chan struct{}]struct{}
|
listeners map[chan struct{}]struct{}
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
shutdownWg sync.WaitGroup
|
||||||
iceConfig ice.Config
|
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||||
|
iceConfig ice.Config
|
||||||
cancelIceMonitor context.CancelFunc
|
cancelIceMonitor context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +52,11 @@ func (w *SRWatcher) Start() {
|
|||||||
w.cancelIceMonitor = cancel
|
w.cancelIceMonitor = cancel
|
||||||
|
|
||||||
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
|
||||||
go iceMonitor.Start(ctx, w.onICEChanged)
|
w.shutdownWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer w.shutdownWg.Done()
|
||||||
|
iceMonitor.Start(ctx, w.onICEChanged)
|
||||||
|
}()
|
||||||
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
w.signalClient.SetOnReconnectedListener(w.onReconnected)
|
||||||
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
w.relayManager.SetOnReconnectedListener(w.onReconnected)
|
||||||
|
|
||||||
@@ -60,14 +64,16 @@ func (w *SRWatcher) Start() {
|
|||||||
|
|
||||||
func (w *SRWatcher) Close() {
|
func (w *SRWatcher) Close() {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
|
||||||
|
|
||||||
if w.cancelIceMonitor == nil {
|
if w.cancelIceMonitor == nil {
|
||||||
|
w.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.cancelIceMonitor()
|
w.cancelIceMonitor()
|
||||||
w.signalClient.SetOnReconnectedListener(nil)
|
w.signalClient.SetOnReconnectedListener(nil)
|
||||||
w.relayManager.SetOnReconnectedListener(nil)
|
w.relayManager.SetOnReconnectedListener(nil)
|
||||||
|
w.mu.Unlock()
|
||||||
|
|
||||||
|
w.shutdownWg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *SRWatcher) NewListener() chan struct{} {
|
func (w *SRWatcher) NewListener() chan struct{} {
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ type DefaultManager struct {
|
|||||||
ctx context.Context
|
ctx context.Context
|
||||||
stop context.CancelFunc
|
stop context.CancelFunc
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
|
shutdownWg sync.WaitGroup
|
||||||
clientNetworks map[route.HAUniqueID]*client.Watcher
|
clientNetworks map[route.HAUniqueID]*client.Watcher
|
||||||
routeSelector *routeselector.RouteSelector
|
routeSelector *routeselector.RouteSelector
|
||||||
serverRouter *server.Router
|
serverRouter *server.Router
|
||||||
@@ -273,6 +274,7 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
|||||||
// Stop stops the manager watchers and clean firewall rules
|
// Stop stops the manager watchers and clean firewall rules
|
||||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||||
m.stop()
|
m.stop()
|
||||||
|
m.shutdownWg.Wait()
|
||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
m.serverRouter.CleanUp()
|
m.serverRouter.CleanUp()
|
||||||
}
|
}
|
||||||
@@ -474,7 +476,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
|||||||
}
|
}
|
||||||
clientNetworkWatcher := client.NewWatcher(config)
|
clientNetworkWatcher := client.NewWatcher(config)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.Start()
|
m.shutdownWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer m.shutdownWg.Done()
|
||||||
|
clientNetworkWatcher.Start()
|
||||||
|
}()
|
||||||
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -516,7 +522,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
|||||||
}
|
}
|
||||||
clientNetworkWatcher = client.NewWatcher(config)
|
clientNetworkWatcher = client.NewWatcher(config)
|
||||||
m.clientNetworks[id] = clientNetworkWatcher
|
m.clientNetworks[id] = clientNetworkWatcher
|
||||||
go clientNetworkWatcher.Start()
|
m.shutdownWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer m.shutdownWg.Done()
|
||||||
|
clientNetworkWatcher.Start()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
update := client.RoutesUpdate{
|
update := client.RoutesUpdate{
|
||||||
UpdateSerial: updateSerial,
|
UpdateSerial: updateSerial,
|
||||||
|
|||||||
@@ -105,11 +105,31 @@ func (r *SysOps) FlushMarkedRoutes() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
|
if prefix.IsSingleIP() {
|
||||||
|
log.Debugf("Adding single IP route: %s via %s", prefix, formatNexthop(nexthop))
|
||||||
|
}
|
||||||
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
return r.routeSocket(unix.RTM_ADD, prefix, nexthop)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
return r.routeSocket(unix.RTM_DELETE, prefix, nexthop)
|
if prefix.IsSingleIP() {
|
||||||
|
log.Debugf("Removing single IP route: %s via %s", prefix, formatNexthop(nexthop))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.routeSocket(unix.RTM_DELETE, prefix, nexthop); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix.IsSingleIP() {
|
||||||
|
log.Debugf("Route removal completed for %s, verifying...", prefix)
|
||||||
|
if exists := r.verifyRouteRemoved(prefix); exists {
|
||||||
|
log.Warnf("Route %s still exists in routing table after removal", prefix)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Verified route %s successfully removed", prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
|
func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) error {
|
||||||
@@ -276,3 +296,51 @@ func prefixToRouteNetmask(prefix netip.Prefix) (route.Addr, error) {
|
|||||||
|
|
||||||
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
|
return nil, fmt.Errorf("unknown IP version in prefix: %s", prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// formatNexthop returns a string representation of the nexthop for logging.
|
||||||
|
func formatNexthop(nexthop Nexthop) string {
|
||||||
|
if nexthop.IP.IsValid() {
|
||||||
|
return nexthop.IP.String()
|
||||||
|
}
|
||||||
|
if nexthop.Intf != nil {
|
||||||
|
return nexthop.Intf.Name
|
||||||
|
}
|
||||||
|
return "direct"
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyRouteRemoved checks if a route still exists in the routing table.
|
||||||
|
func (r *SysOps) verifyRouteRemoved(prefix netip.Prefix) bool {
|
||||||
|
rib, err := retryFetchRIB()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to fetch RIB for route verification: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to parse RIB for route verification: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, msg := range msgs {
|
||||||
|
rtMsg, ok := msg.(*route.RouteMessage)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
routeInfo, err := MsgToRoute(rtMsg)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if routeInfo.Dst == prefix {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -17,8 +17,7 @@ type Conn struct {
|
|||||||
ID hooks.ConnectionID
|
ID hooks.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
return closeConn(c.ID, c.Conn)
|
return closeConn(c.ID, c.Conn)
|
||||||
}
|
}
|
||||||
@@ -29,7 +28,7 @@ type TCPConn struct {
|
|||||||
ID hooks.ConnectionID
|
ID hooks.ConnectionID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *TCPConn) Close() error {
|
func (c *TCPConn) Close() error {
|
||||||
return closeConn(c.ID, c.TCPConn)
|
return closeConn(c.ID, c.TCPConn)
|
||||||
}
|
}
|
||||||
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
|
|||||||
// closeConn is a helper function to close connections and execute close hooks.
|
// closeConn is a helper function to close connections and execute close hooks.
|
||||||
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||||
err := conn.Close()
|
err := conn.Close()
|
||||||
|
cleanupConnID(id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupConnID executes close hooks for a connection ID.
|
||||||
|
func cleanupConnID(id hooks.ConnectionID) {
|
||||||
closeHooks := hooks.GetCloseHooks()
|
closeHooks := hooks.GetCloseHooks()
|
||||||
for _, hook := range closeHooks {
|
for _, hook := range closeHooks {
|
||||||
if err := hook(id); err != nil {
|
if err := hook(id); err != nil {
|
||||||
log.Errorf("Error executing close hook: %v", err)
|
log.Errorf("Error executing close hook: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
|
|||||||
}
|
}
|
||||||
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
log.Errorf("failed to close connection: %v", err)
|
log.Errorf("failed to close connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
|||||||
|
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
cleanupConnID(connID)
|
||||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
|
|||||||
|
|
||||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
return fmt.Errorf("resolve address %s: %w", address, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|||||||
return c.PacketConn.WriteTo(b, addr)
|
return c.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *PacketConn) Close() error {
|
func (c *PacketConn) Close() error {
|
||||||
defer c.seenAddrs.Clear()
|
defer c.seenAddrs.Clear()
|
||||||
return closeConn(c.ID, c.PacketConn)
|
return closeConn(c.ID, c.PacketConn)
|
||||||
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|||||||
return c.UDPConn.WriteTo(b, addr)
|
return c.UDPConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
|
||||||
func (c *UDPConn) Close() error {
|
func (c *UDPConn) Close() error {
|
||||||
defer c.seenAddrs.Clear()
|
defer c.seenAddrs.Clear()
|
||||||
return closeConn(c.ID, c.UDPConn)
|
return closeConn(c.ID, c.UDPConn)
|
||||||
|
|||||||
@@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
|
|||||||
var err error
|
var err error
|
||||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
|
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("createConnection error: %v", err)
|
return fmt.Errorf("create connection: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
|||||||
var err error
|
var err error
|
||||||
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
|
conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("createConnection error: %v", err)
|
return fmt.Errorf("create connection: %w", err)
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user