diff --git a/client/embed/embed.go b/client/embed/embed.go index 9ded618c5..fe95b1942 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error { // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available - run := make(chan error, 1) + run := make(chan struct{}, 1) + clientErr := make(chan error, 1) go func() { if err := client.Run(run); err != nil { - run <- err + clientErr <- err } }() @@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error { return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err()) } return startCtx.Err() - case err := <-run: - if err != nil { - if stopErr := client.Stop(); stopErr != nil { - return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err) - } - return fmt.Errorf("startup: %w", err) - } + case err := <-clientErr: + return fmt.Errorf("startup: %w", err) + case <-run: } c.connect = client diff --git a/client/firewall/iface.go b/client/firewall/iface.go index d842abaa1..b83c5f912 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -4,12 +4,13 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { Name() string - Address() device.WGAddress + Address() wgaddr.Address IsUserspaceBind() bool SetFilter(device.PacketFilter) error GetDevice() *device.FilteredDevice diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 542a13c1f..652ab1b3e 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -13,7 +13,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -31,7 +31,7 @@ type Manager struct { // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string - Address() iface.WGAddress + Address() wgaddr.Address IsUserspaceBind() bool } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index 23926d059..af9f5dd23 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -10,15 +10,15 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) var ifaceMock = &iFaceMock{ NameFunc: func() string { return "lo" }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: net.ParseIP("10.20.0.1"), Network: &net.IPNet{ IP: net.ParseIP("10.20.0.0"), @@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{ // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string - AddressFunc func() iface.WGAddress + AddressFunc func() wgaddr.Address } func (i *iFaceMock) Name() string { @@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string { panic("NameFunc is not set") } -func (i *iFaceMock) Address() iface.WGAddress { +func (i *iFaceMock) Address() wgaddr.Address { if i.AddressFunc != nil { return i.AddressFunc() } @@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) { NameFunc: func() string { return "lo" }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: net.ParseIP("10.20.0.1"), Network: &net.IPNet{ IP: net.ParseIP("10.20.0.0"), @@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) { NameFunc: func() string { return "lo" }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: net.ParseIP("10.20.0.1"), Network: &net.IPNet{ IP: net.ParseIP("10.20.0.0"), diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 2a7120bbf..6ef159e01 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,21 +4,20 @@ import ( "fmt" "sync" - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type InterfaceState struct { - NameStr string `json:"name"` - WGAddress iface.WGAddress `json:"wg_address"` - UserspaceBind bool `json:"userspace_bind"` + NameStr string `json:"name"` + WGAddress wgaddr.Address `json:"wg_address"` + UserspaceBind bool `json:"userspace_bind"` } func (i *InterfaceState) Name() string { return i.NameStr } -func (i *InterfaceState) Address() device.WGAddress { +func (i *InterfaceState) Address() wgaddr.Address { return i.WGAddress } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 475601d17..a5809471c 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -14,7 +14,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -29,7 +29,7 @@ const ( // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string - Address() iface.WGAddress + Address() wgaddr.Address IsUserspaceBind() bool } diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 3f1a6e4b3..373743a08 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -16,15 +16,15 @@ import ( "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) var ifaceMock = &iFaceMock{ NameFunc: func() string { return "lo" }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: net.ParseIP("100.96.0.1"), Network: &net.IPNet{ IP: net.ParseIP("100.96.0.0"), @@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{ // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string - AddressFunc func() iface.WGAddress + AddressFunc func() wgaddr.Address } func (i *iFaceMock) Name() string { @@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string { panic("NameFunc is not set") } -func (i *iFaceMock) Address() iface.WGAddress { +func (i *iFaceMock) Address() wgaddr.Address { if i.AddressFunc != nil { return i.AddressFunc() } @@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) { NameFunc: func() string { return "lo" }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: net.ParseIP("100.96.0.1"), Network: &net.IPNet{ IP: net.ParseIP("100.96.0.0"), diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index facca1cec..f805623d6 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -3,21 +3,20 @@ package nftables import ( "fmt" - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type InterfaceState struct { - NameStr string `json:"name"` - WGAddress iface.WGAddress `json:"wg_address"` - UserspaceBind bool `json:"userspace_bind"` + NameStr string `json:"name"` + WGAddress wgaddr.Address `json:"wg_address"` + UserspaceBind bool `json:"userspace_bind"` } func (i *InterfaceState) Name() string { return i.NameStr } -func (i *InterfaceState) Address() device.WGAddress { +func (i *InterfaceState) Address() wgaddr.Address { return i.WGAddress } diff --git a/client/firewall/uspfilter/common/iface.go b/client/firewall/uspfilter/common/iface.go index d44e79509..7296953db 100644 --- a/client/firewall/uspfilter/common/iface.go +++ b/client/firewall/uspfilter/common/iface.go @@ -3,14 +3,14 @@ package common import ( wgdevice "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { SetFilter(device.PacketFilter) error - Address() iface.WGAddress + Address() wgaddr.Address GetWGDevice() *wgdevice.Device GetDevice() *device.FilteredDevice } diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 49cc832e6..1a6566fa8 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -1,6 +1,7 @@ package conntrack import ( + "context" "fmt" "net/netip" "sync" @@ -44,8 +45,8 @@ type ICMPTracker struct { connections map[ICMPConnKey]*ICMPConnTrack timeout time.Duration cleanupTicker *time.Ticker + tickerCancel context.CancelFunc mutex sync.RWMutex - done chan struct{} flowLogger nftypes.FlowLogger } @@ -55,16 +56,18 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty timeout = DefaultICMPTimeout } + ctx, cancel := context.WithCancel(context.Background()) + tracker := &ICMPTracker{ logger: logger, connections: make(map[ICMPConnKey]*ICMPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), - done: make(chan struct{}), + tickerCancel: cancel, flowLogger: flowLogger, } - go tracker.cleanupRoutine() + go tracker.cleanupRoutine(ctx) return tracker } @@ -164,12 +167,14 @@ func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint return true } -func (t *ICMPTracker) cleanupRoutine() { +func (t *ICMPTracker) cleanupRoutine(ctx context.Context) { + defer t.tickerCancel() + for { select { case <-t.cleanupTicker.C: t.cleanup() - case <-t.done: + case <-ctx.Done(): return } } @@ -192,8 +197,7 @@ func (t *ICMPTracker) cleanup() { // Close stops the cleanup routine and releases resources func (t *ICMPTracker) Close() { - t.cleanupTicker.Stop() - close(t.done) + t.tickerCancel() t.mutex.Lock() t.connections = nil diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index b5e470bf9..a1f17966a 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -3,6 +3,7 @@ package conntrack // TODO: Send RST packets for invalid/timed-out connections import ( + "context" "net/netip" "sync" "sync/atomic" @@ -122,7 +123,7 @@ type TCPTracker struct { connections map[ConnKey]*TCPConnTrack mutex sync.RWMutex cleanupTicker *time.Ticker - done chan struct{} + tickerCancel context.CancelFunc timeout time.Duration flowLogger nftypes.FlowLogger } @@ -133,16 +134,18 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp timeout = DefaultTCPTimeout } + ctx, cancel := context.WithCancel(context.Background()) + tracker := &TCPTracker{ logger: logger, connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), - done: make(chan struct{}), + tickerCancel: cancel, timeout: timeout, flowLogger: flowLogger, } - go tracker.cleanupRoutine() + go tracker.cleanupRoutine(ctx) return tracker } @@ -396,12 +399,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { return false } -func (t *TCPTracker) cleanupRoutine() { +func (t *TCPTracker) cleanupRoutine(ctx context.Context) { + defer t.cleanupTicker.Stop() + for { select { case <-t.cleanupTicker.C: t.cleanup() - case <-t.done: + case <-ctx.Done(): return } } @@ -444,8 +449,7 @@ func (t *TCPTracker) cleanup() { // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { - t.cleanupTicker.Stop() - close(t.done) + t.tickerCancel() // Clean up all remaining IPs t.mutex.Lock() diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 94db24f5f..7ca493a9d 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -1,6 +1,7 @@ package conntrack import ( + "context" "net/netip" "sync" "time" @@ -31,8 +32,8 @@ type UDPTracker struct { connections map[ConnKey]*UDPConnTrack timeout time.Duration cleanupTicker *time.Ticker + tickerCancel context.CancelFunc mutex sync.RWMutex - done chan struct{} flowLogger nftypes.FlowLogger } @@ -42,16 +43,18 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp timeout = DefaultUDPTimeout } + ctx, cancel := context.WithCancel(context.Background()) + tracker := &UDPTracker{ logger: logger, connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), - done: make(chan struct{}), + tickerCancel: cancel, flowLogger: flowLogger, } - go tracker.cleanupRoutine() + go tracker.cleanupRoutine(ctx) return tracker } @@ -140,12 +143,14 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort } // cleanupRoutine periodically removes stale connections -func (t *UDPTracker) cleanupRoutine() { +func (t *UDPTracker) cleanupRoutine(ctx context.Context) { + defer t.cleanupTicker.Stop() + for { select { case <-t.cleanupTicker.C: t.cleanup() - case <-t.done: + case <-ctx.Done(): return } } @@ -168,8 +173,7 @@ func (t *UDPTracker) cleanup() { // Close stops the cleanup routine and releases resources func (t *UDPTracker) Close() { - t.cleanupTicker.Stop() - close(t.done) + t.tickerCancel() t.mutex.Lock() t.connections = nil diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 14b912908..7ad1e0e4b 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -1,6 +1,7 @@ package conntrack import ( + "context" "net/netip" "testing" "time" @@ -34,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) { assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.cleanupTicker) - assert.NotNil(t, tracker.done) + assert.NotNil(t, tracker.tickerCancel) }) } } @@ -159,18 +160,21 @@ func TestUDPTracker_Cleanup(t *testing.T) { timeout := 50 * time.Millisecond cleanupInterval := 25 * time.Millisecond + ctx, tickerCancel := context.WithCancel(context.Background()) + defer tickerCancel() + // Create tracker with custom cleanup interval tracker := &UDPTracker{ connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(cleanupInterval), - done: make(chan struct{}), + tickerCancel: tickerCancel, logger: logger, flowLogger: flowLogger, } // Start cleanup routine - go tracker.cleanupRoutine() + go tracker.cleanupRoutine(ctx) // Add some connections connections := []struct { diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go index 890b7a30d..0715ddc41 100644 --- a/client/firewall/uspfilter/localip_test.go +++ b/client/firewall/uspfilter/localip_test.go @@ -7,19 +7,19 @@ import ( "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) func TestLocalIPManager(t *testing.T) { tests := []struct { name string - setupAddr iface.WGAddress + setupAddr wgaddr.Address testIP netip.Addr expected bool }{ { name: "Localhost range", - setupAddr: iface.WGAddress{ + setupAddr: wgaddr.Address{ IP: net.ParseIP("192.168.1.1"), Network: &net.IPNet{ IP: net.ParseIP("192.168.1.0"), @@ -31,7 +31,7 @@ func TestLocalIPManager(t *testing.T) { }, { name: "Localhost standard address", - setupAddr: iface.WGAddress{ + setupAddr: wgaddr.Address{ IP: net.ParseIP("192.168.1.1"), Network: &net.IPNet{ IP: net.ParseIP("192.168.1.0"), @@ -43,7 +43,7 @@ func TestLocalIPManager(t *testing.T) { }, { name: "Localhost range edge", - setupAddr: iface.WGAddress{ + setupAddr: wgaddr.Address{ IP: net.ParseIP("192.168.1.1"), Network: &net.IPNet{ IP: net.ParseIP("192.168.1.0"), @@ -55,7 +55,7 @@ func TestLocalIPManager(t *testing.T) { }, { name: "Local IP matches", - setupAddr: iface.WGAddress{ + setupAddr: wgaddr.Address{ IP: net.ParseIP("192.168.1.1"), Network: &net.IPNet{ IP: net.ParseIP("192.168.1.0"), @@ -67,7 +67,7 @@ func TestLocalIPManager(t *testing.T) { }, { name: "Local IP doesn't match", - setupAddr: iface.WGAddress{ + setupAddr: wgaddr.Address{ IP: net.ParseIP("192.168.1.1"), Network: &net.IPNet{ IP: net.ParseIP("192.168.1.0"), @@ -79,7 +79,7 @@ func TestLocalIPManager(t *testing.T) { }, { name: "IPv6 address", - setupAddr: iface.WGAddress{ + setupAddr: wgaddr.Address{ IP: net.ParseIP("fe80::1"), Network: &net.IPNet{ IP: net.ParseIP("fe80::"), @@ -96,7 +96,7 @@ func TestLocalIPManager(t *testing.T) { manager := newLocalIPManager() mock := &IFaceMock{ - AddressFunc: func() iface.WGAddress { + AddressFunc: func() wgaddr.Address { return tt.setupAddr }, } diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go index c6902dfea..ba97c2643 100644 --- a/client/firewall/uspfilter/uspfilter_filter_test.go +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -12,9 +12,9 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) func TestPeerACLFiltering(t *testing.T) { @@ -26,8 +26,8 @@ func TestPeerACLFiltering(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: localIP, Network: wgNet, } @@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: localIP, Network: wgNet, } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index bc1ce4398..e525b6246 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -18,8 +18,8 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/log" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" ) @@ -28,7 +28,7 @@ var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogg type IFaceMock struct { SetFilterFunc func(device.PacketFilter) error - AddressFunc func() iface.WGAddress + AddressFunc func() wgaddr.Address GetWGDeviceFunc func() *wgdevice.Device GetDeviceFunc func() *device.FilteredDevice } @@ -54,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { return i.SetFilterFunc(iface) } -func (i *IFaceMock) Address() iface.WGAddress { +func (i *IFaceMock) Address() wgaddr.Address { if i.AddressFunc == nil { - return iface.WGAddress{} + return wgaddr.Address{} } return i.AddressFunc() } @@ -269,8 +269,8 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: net.ParseIP("100.10.0.100"), Network: &net.IPNet{ IP: net.ParseIP("100.10.0.0"), diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 6897f04a1..66ec6a00d 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -13,6 +13,8 @@ import ( "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" + + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type RecvMessage struct { @@ -51,9 +53,10 @@ type ICEBind struct { muUDPMux sync.Mutex udpMux *UniversalUDPMuxDefault + address wgaddr.Address } -func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { +func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind { b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, @@ -63,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { endpoints: make(map[netip.Addr]net.Conn), closedChan: make(chan struct{}), closed: true, + address: address, } rc := receiverCreator{ @@ -142,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r s.udpMux = NewUniversalUDPMuxDefault( UniversalUDPMuxParams{ - UDPConn: conn, - Net: s.transportNet, - FilterFn: s.filterFn, + UDPConn: conn, + Net: s.transportNet, + FilterFn: s.filterFn, + WGAddress: s.address, }, ) return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go index ebbefe035..6f851393e 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/bind/udp_mux_universal.go @@ -17,6 +17,8 @@ import ( "github.com/pion/logging" "github.com/pion/stun/v2" "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/client/iface/wgaddr" ) // FilterFn is a function that filters out candidates based on the address. @@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct { XORMappedAddrCacheTTL time.Duration Net transport.Net FilterFn FilterFn + WGAddress wgaddr.Address } // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux @@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef mux: m, logger: params.Logger, filterFn: params.FilterFn, + address: params.WGAddress, } // embed UDPMux @@ -118,6 +122,7 @@ type udpConn struct { filterFn FilterFn // TODO: reset cache on route changes addrCache sync.Map + address wgaddr.Address } func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { @@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error { return nil } + if u.address.Network.Contains(a.AsSlice()) { + log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) + return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) + } + if isRouted, prefix, err := u.filterFn(a); err != nil { log.Errorf("Failed to check if address %s is routed: %v", addr, err) } else { diff --git a/client/iface/device.go b/client/iface/device.go index 86e9dab4b..81f2e0f47 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -9,13 +9,14 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create() (device.WGConfigurer, error) Up() (*bind.UniversalUDPMuxDefault, error) - UpdateAddr(address WGAddress) error - WgAddress() WGAddress + UpdateAddr(address wgaddr.Address) error + WgAddress() wgaddr.Address DeviceName() string Close() error FilteredDevice() *device.FilteredDevice diff --git a/client/iface/device/device_android.go b/client/iface/device/device_android.go index 55081e181..ab3e611e1 100644 --- a/client/iface/device/device_android.go +++ b/client/iface/device/device_android.go @@ -13,11 +13,12 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform type WGTunDevice struct { - address WGAddress + address wgaddr.Address port int key string mtu int @@ -31,7 +32,7 @@ type WGTunDevice struct { configurer WGConfigurer } -func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { +func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { return &WGTunDevice{ address: address, port: port, @@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { +func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error { // todo implement return nil } @@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string { return t.name } -func (t *WGTunDevice) WgAddress() WGAddress { +func (t *WGTunDevice) WgAddress() wgaddr.Address { return t.address } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index 1a5635ff2..01bfbf381 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -13,11 +13,12 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type TunDevice struct { name string - address WGAddress + address wgaddr.Address port int key string mtu int @@ -29,7 +30,7 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, @@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *TunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address wgaddr.Address) error { t.address = address return t.assignAddr() } @@ -106,7 +107,7 @@ func (t *TunDevice) Close() error { return nil } -func (t *TunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } diff --git a/client/iface/device/device_ios.go b/client/iface/device/device_ios.go index b106d475c..56d44d68e 100644 --- a/client/iface/device/device_ios.go +++ b/client/iface/device/device_ios.go @@ -14,11 +14,12 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type TunDevice struct { name string - address WGAddress + address wgaddr.Address port int key string iceBind *bind.ICEBind @@ -30,7 +31,7 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { return &TunDevice{ name: name, address: address, @@ -120,11 +121,11 @@ func (t *TunDevice) Close() error { return nil } -func (t *TunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } -func (t *TunDevice) UpdateAddr(addr WGAddress) error { +func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { // todo implement return nil } diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index fe1d1147f..988ed1b39 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -14,12 +14,13 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/sharedsock" ) type TunKernelDevice struct { name string - address WGAddress + address wgaddr.Address wgPort int key string mtu int @@ -34,7 +35,7 @@ type TunKernelDevice struct { filterFn bind.FilterFn } -func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { +func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { ctx, cancel := context.WithCancel(context.Background()) return &TunKernelDevice{ ctx: ctx, @@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return nil, err } bindParams := bind.UniversalUDPMuxParams{ - UDPConn: rawSock, - Net: t.transportNet, - FilterFn: t.filterFn, + UDPConn: rawSock, + Net: t.transportNet, + FilterFn: t.filterFn, + WGAddress: t.address, } mux := bind.NewUniversalUDPMuxDefault(bindParams) go mux.ReadFromConn(t.ctx) @@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return t.udpMux, nil } -func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { +func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error { t.address = address return t.assignAddr() } @@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error { return closErr } -func (t *TunKernelDevice) WgAddress() WGAddress { +func (t *TunKernelDevice) WgAddress() wgaddr.Address { return t.address } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index 0cb02fd19..d3c92235e 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -13,12 +13,13 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" nbnet "github.com/netbirdio/netbird/util/net" ) type TunNetstackDevice struct { name string - address WGAddress + address wgaddr.Address port int key string mtu int @@ -34,7 +35,7 @@ type TunNetstackDevice struct { net *netstack.Net } -func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { +func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { return &TunNetstackDevice{ name: name, address: address, @@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { +func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error { return nil } @@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error { return nil } -func (t *TunNetstackDevice) WgAddress() WGAddress { +func (t *TunNetstackDevice) WgAddress() wgaddr.Address { return t.address } diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 07570617a..c45ae9676 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -12,11 +12,12 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type USPDevice struct { name string - address WGAddress + address wgaddr.Address port int key string mtu int @@ -28,7 +29,7 @@ type USPDevice struct { configurer WGConfigurer } -func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { +func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { log.Infof("using userspace bind mode") return &USPDevice{ @@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *USPDevice) UpdateAddr(address WGAddress) error { +func (t *USPDevice) UpdateAddr(address wgaddr.Address) error { t.address = address return t.assignAddr() } @@ -113,7 +114,7 @@ func (t *USPDevice) Close() error { return nil } -func (t *USPDevice) WgAddress() WGAddress { +func (t *USPDevice) WgAddress() wgaddr.Address { return t.address } diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index 0fd1b3326..41e615bc2 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -13,13 +13,14 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" type TunDevice struct { name string - address WGAddress + address wgaddr.Address port int key string mtu int @@ -32,7 +33,7 @@ type TunDevice struct { configurer WGConfigurer } -func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { +func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { return &TunDevice{ name: name, address: address, @@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *TunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address wgaddr.Address) error { t.address = address return t.assignAddr() } @@ -139,7 +140,7 @@ func (t *TunDevice) Close() error { } return nil } -func (t *TunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() wgaddr.Address { return t.address } diff --git a/client/iface/device/wg_link_freebsd.go b/client/iface/device/wg_link_freebsd.go index 104010f47..9067790e4 100644 --- a/client/iface/device/wg_link_freebsd.go +++ b/client/iface/device/wg_link_freebsd.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/freebsd" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type wgLink struct { @@ -56,7 +57,7 @@ func (l *wgLink) up() error { return nil } -func (l *wgLink) assignAddr(address WGAddress) error { +func (l *wgLink) assignAddr(address wgaddr.Address) error { link, err := freebsd.LinkByName(l.name) if err != nil { return fmt.Errorf("link by name: %w", err) diff --git a/client/iface/device/wg_link_linux.go b/client/iface/device/wg_link_linux.go index a15cffe48..d941cd022 100644 --- a/client/iface/device/wg_link_linux.go +++ b/client/iface/device/wg_link_linux.go @@ -8,6 +8,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/vishvananda/netlink" + + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type wgLink struct { @@ -90,7 +92,7 @@ func (l *wgLink) up() error { return nil } -func (l *wgLink) assignAddr(address WGAddress) error { +func (l *wgLink) assignAddr(address wgaddr.Address) error { //delete existing addresses list, err := netlink.AddrList(l, 0) if err != nil { diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 5cbeb70f8..a1e246fc5 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -7,13 +7,14 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type WGTunDevice interface { Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Up() (*bind.UniversalUDPMuxDefault, error) - UpdateAddr(address WGAddress) error - WgAddress() WGAddress + UpdateAddr(address wgaddr.Address) error + WgAddress() wgaddr.Address DeviceName() string Close() error FilteredDevice() *device.FilteredDevice diff --git a/client/iface/iface.go b/client/iface/iface.go index 40bd51fbb..9d5262aed 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) @@ -28,8 +29,6 @@ const ( WgInterfaceDefault = configurer.WgInterfaceDefault ) -type WGAddress = device.WGAddress - type wgProxyFactory interface { GetProxy() wgproxy.Proxy Free() error @@ -72,7 +71,7 @@ func (w *WGIface) Name() string { } // Address returns the interface address -func (w *WGIface) Address() device.WGAddress { +func (w *WGIface) Address() wgaddr.Address { return w.tun.WgAddress() } @@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error { w.mu.Lock() defer w.mu.Unlock() - addr, err := device.ParseWGAddress(newAddr) + addr, err := wgaddr.ParseWGAddress(newAddr) if err != nil { return err } diff --git a/client/iface/iface_new_android.go b/client/iface/iface_new_android.go index 69a8d1fd4..35046b887 100644 --- a/client/iface/iface_new_android.go +++ b/client/iface/iface_new_android.go @@ -3,17 +3,18 @@ package iface import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(opts.Address) + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) if err != nil { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) wgIFace := &WGIface{ userspaceBind: true, diff --git a/client/iface/iface_new_darwin.go b/client/iface/iface_new_darwin.go index a92d74e0f..93fd7fd5c 100644 --- a/client/iface/iface_new_darwin.go +++ b/client/iface/iface_new_darwin.go @@ -6,17 +6,18 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(opts.Address) + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) if err != nil { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) var tun WGTunDevice if netstack.IsEnabled() { diff --git a/client/iface/iface_new_ios.go b/client/iface/iface_new_ios.go index 363f95e11..317ee0f46 100644 --- a/client/iface/iface_new_ios.go +++ b/client/iface/iface_new_ios.go @@ -5,17 +5,18 @@ package iface import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(opts.Address) + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) if err != nil { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) wgIFace := &WGIface{ tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), diff --git a/client/iface/iface_new_unix.go b/client/iface/iface_new_unix.go index f10b17c9a..23ee7236f 100644 --- a/client/iface/iface_new_unix.go +++ b/client/iface/iface_new_unix.go @@ -8,12 +8,13 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(opts.Address) + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) if err != nil { return nil, err } @@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { wgIFace := &WGIface{} if netstack.IsEnabled() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.userspaceBind = true wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) @@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { return wgIFace, nil } if device.ModuleTunIsLoaded() { - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.userspaceBind = true wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) diff --git a/client/iface/iface_new_windows.go b/client/iface/iface_new_windows.go index 2e6355496..413062940 100644 --- a/client/iface/iface_new_windows.go +++ b/client/iface/iface_new_windows.go @@ -4,16 +4,17 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/netstack" + wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) // NewWGIFace Creates a new WireGuard interface instance func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { - wgAddress, err := device.ParseWGAddress(opts.Address) + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) if err != nil { return nil, err } - iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress) var tun WGTunDevice if netstack.IsEnabled() { diff --git a/client/iface/device/address.go b/client/iface/wgaddr/address.go similarity index 61% rename from client/iface/device/address.go rename to client/iface/wgaddr/address.go index 15de301da..e5079258c 100644 --- a/client/iface/device/address.go +++ b/client/iface/wgaddr/address.go @@ -1,29 +1,29 @@ -package device +package wgaddr import ( "fmt" "net" ) -// WGAddress WireGuard parsed address -type WGAddress struct { +// Address WireGuard parsed address +type Address struct { IP net.IP Network *net.IPNet } // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address -func ParseWGAddress(address string) (WGAddress, error) { +func ParseWGAddress(address string) (Address, error) { ip, network, err := net.ParseCIDR(address) if err != nil { - return WGAddress{}, err + return Address{}, err } - return WGAddress{ + return Address{ IP: ip, Network: network, }, nil } -func (addr WGAddress) String() string { +func (addr Address) String() string { maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) } diff --git a/client/installer.nsis b/client/installer.nsis index 743c81a6d..5219058a8 100644 --- a/client/installer.nsis +++ b/client/installer.nsis @@ -22,6 +22,8 @@ !define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}" !define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}" +!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run" + Unicode True ###################################################################### @@ -68,6 +70,9 @@ ShowInstDetails Show !insertmacro MUI_PAGE_DIRECTORY +; Custom page for autostart checkbox +Page custom AutostartPage AutostartPageLeave + !insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_FINISH @@ -80,8 +85,36 @@ ShowInstDetails Show !insertmacro MUI_LANGUAGE "English" +; Variables for autostart option +Var AutostartCheckbox +Var AutostartEnabled + ###################################################################### +; Function to create the autostart options page +Function AutostartPage + !insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows." + + nsDialogs::Create 1018 + Pop $0 + + ${If} $0 == error + Abort + ${EndIf} + + ${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts" + Pop $AutostartCheckbox + ${NSD_Check} $AutostartCheckbox ; Default to checked + StrCpy $AutostartEnabled "1" ; Default to enabled + + nsDialogs::Show +FunctionEnd + +; Function to handle leaving the autostart page +Function AutostartPageLeave + ${NSD_GetState} $AutostartCheckbox $AutostartEnabled +FunctionEnd + Function GetAppFromCommand Exch $1 Push $2 @@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}" WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" +; Create autostart registry entry based on checkbox +DetailPrint "Autostart enabled: $AutostartEnabled" +${If} $AutostartEnabled == "1" + WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe" + DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe" +${Else} + DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" + DetailPrint "Autostart not enabled by user" +${EndIf} + EnVar::SetHKLM EnVar::AddValueEx "path" "$INSTDIR" @@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' # kill ui client -ExecWait `taskkill /im ${UI_APP_EXE}.exe` +ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` + +; Remove autostart registry entry +DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" # wait the service uninstall take unblock the executable Sleep 3000 diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 9e1659455..b54a105b3 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,7 +9,7 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/netflow" mgmProto "github.com/netbirdio/netbird/management/proto" @@ -49,7 +49,7 @@ func TestDefaultManager(t *testing.T) { } ifaceMock.EXPECT().Name().Return("lo").AnyTimes() - ifaceMock.EXPECT().Address().Return(iface.WGAddress{ + ifaceMock.EXPECT().Address().Return(wgaddr.Address{ IP: ip, Network: network, }).AnyTimes() @@ -343,7 +343,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { } ifaceMock.EXPECT().Name().Return("lo").AnyTimes() - ifaceMock.EXPECT().Address().Return(iface.WGAddress{ + ifaceMock.EXPECT().Address().Return(wgaddr.Address{ IP: ip, Network: network, }).AnyTimes() diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index 08aa4fd5a..95d5a2c58 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -10,8 +10,8 @@ import ( gomock "github.com/golang/mock/gomock" wgdevice "golang.zx2c4.com/wireguard/device" - iface "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) // MockIFaceMapper is a mock of IFaceMapper interface. @@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder { } // Address mocks base method. -func (m *MockIFaceMapper) Address() iface.WGAddress { +func (m *MockIFaceMapper) Address() wgaddr.Address { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Address") - ret0, _ := ret[0].(iface.WGAddress) + ret0, _ := ret[0].(wgaddr.Address) return ret0 } diff --git a/client/internal/connect.go b/client/internal/connect.go index bf513ed39..504c88c6f 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -61,7 +61,7 @@ func NewConnectClient( } // Run with main logic. -func (c *ConnectClient) Run(runningChan chan error) error { +func (c *ConnectClient) Run(runningChan chan struct{}) error { return c.run(MobileDependency{}, runningChan) } @@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS( return c.run(mobileDependency, nil) } -func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error { +func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error { defer func() { if r := recover(); r != nil { rec := c.statusRecorder @@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } defer c.statusRecorder.ClientStop() - runningChanOpen := true operation := func() error { // if context cancelled we not start new backoff cycle - if c.isContextCancelled() { + if c.ctx.Err() != nil { return nil } @@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) - if runningChan != nil && runningChanOpen { - runningChan <- nil - close(runningChan) - runningChanOpen = false + if runningChan != nil { + select { + case runningChan <- struct{}{}: + default: + } } <-engineCtx.Done() @@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error { return nil } -func (c *ConnectClient) isContextCancelled() bool { - select { - case <-c.ctx.Done(): - return true - default: - return false - } -} - // SetNetworkMapPersistence enables or disables network map persistence. // When enabled, the last received network map will be stored and can be retrieved // through the Engine's getLatestNetworkMap method. When disabled, any stored diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index d60edfa55..8871158ed 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -22,6 +22,7 @@ import ( "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" pfmock "github.com/netbirdio/netbird/client/iface/mocks" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -40,9 +41,9 @@ func (w *mocWGIface) Name() string { panic("implement me") } -func (w *mocWGIface) Address() iface.WGAddress { +func (w *mocWGIface) Address() wgaddr.Address { ip, network, _ := net.ParseCIDR("100.66.100.0/24") - return iface.WGAddress{ + return wgaddr.Address{ IP: ip, Network: network, } diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go index 69bc83659..c6c1752e5 100644 --- a/client/internal/dns/wgiface.go +++ b/client/internal/dns/wgiface.go @@ -5,15 +5,15 @@ package dns import ( "net" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) // WGIface defines subset methods of interface required for manager type WGIface interface { Name() string - Address() iface.WGAddress + Address() wgaddr.Address ToInterface() *net.Interface IsUserspaceBind() bool GetFilter() device.PacketFilter diff --git a/client/internal/dns/wgiface_windows.go b/client/internal/dns/wgiface_windows.go index 765132fdb..74e5c75a5 100644 --- a/client/internal/dns/wgiface_windows.go +++ b/client/internal/dns/wgiface_windows.go @@ -1,15 +1,15 @@ package dns import ( - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) // WGIface defines subset methods of interface required for manager type WGIface interface { Name() string - Address() iface.WGAddress + Address() wgaddr.Address IsUserspaceBind() bool GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice diff --git a/client/internal/engine.go b/client/internal/engine.go index e54ff35f6..2627e5232 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1641,16 +1641,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult { return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns) } +// restartEngine restarts the engine by cancelling the client context func (e *Engine) restartEngine() { - log.Info("restarting engine") - CtxGetState(e.ctx).Set(StatusConnecting) + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() - if err := e.Stop(); err != nil { - log.Errorf("Failed to stop engine: %v", err) + if e.ctx.Err() != nil { + return } + log.Info("restarting engine") + CtxGetState(e.ctx).Set(StatusConnecting) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection) - log.Infof("cancelling client, engine will be recreated") + log.Infof("cancelling client context, engine will be recreated") e.clientCancel() } @@ -1662,34 +1665,17 @@ func (e *Engine) startNetworkMonitor() { e.networkMonitor = networkmonitor.New() go func() { - var mu sync.Mutex - var debounceTimer *time.Timer - - // Start the network monitor with a callback, Start will block until the monitor is stopped, - // a network change is detected, or an error occurs on start up - err := e.networkMonitor.Start(e.ctx, func() { - // This function is called when a network change is detected - mu.Lock() - defer mu.Unlock() - - if debounceTimer != nil { - log.Infof("Network monitor: detected network change, reset debounceTimer") - debounceTimer.Stop() + if err := e.networkMonitor.Listen(e.ctx); err != nil { + if errors.Is(err, context.Canceled) { + log.Infof("network monitor stopped") + return } - - // Set a new timer to debounce rapid network changes - debounceTimer = time.AfterFunc(2*time.Second, func() { - // This function is called after the debounce period - mu.Lock() - defer mu.Unlock() - - log.Infof("Network monitor: detected network change, restarting engine") - e.restartEngine() - }) - }) - if err != nil && !errors.Is(err, networkmonitor.ErrStopped) { - log.Errorf("Network monitor: %v", err) + log.Errorf("network monitor error: %v", err) + return } + + log.Infof("Network monitor: detected network change, restarting engine") + e.restartEngine() }() } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 60f07cbec..828823de8 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" @@ -75,7 +76,7 @@ type MockWGIface struct { CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error IsUserspaceBindFunc func() bool NameFunc func() string - AddressFunc func() device.WGAddress + AddressFunc func() wgaddr.Address ToInterfaceFunc func() *net.Interface UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpdateAddrFunc func(newAddr string) error @@ -114,7 +115,7 @@ func (m *MockWGIface) Name() string { return m.NameFunc() } -func (m *MockWGIface) Address() device.WGAddress { +func (m *MockWGIface) Address() wgaddr.Address { return m.AddressFunc() } @@ -364,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { RemovePeerFunc: func(peerKey string) error { return nil }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ IP: net.ParseIP("10.20.0.1"), Network: &net.IPNet{ IP: net.ParseIP("10.20.0.0"), diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index 65b425015..ffeffaf41 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) @@ -20,7 +21,7 @@ type wgIfaceBase interface { CreateOnAndroid(routeRange []string, ip string, domains []string) error IsUserspaceBind() bool Name() string - Address() device.WGAddress + Address() wgaddr.Address ToInterface() *net.Interface Up() (*bind.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/check_change_bsd.go similarity index 90% rename from client/internal/networkmonitor/monitor_bsd.go rename to client/internal/networkmonitor/check_change_bsd.go index 4dc2c1aa3..bb327a877 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -16,7 +16,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) if err != nil { return fmt.Errorf("failed to open routing socket: %v", err) @@ -28,18 +28,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca } }() - go func() { - <-ctx.Done() - err := unix.Close(fd) - if err != nil && !errors.Is(err, unix.EBADF) { - log.Debugf("Network monitor: closed routing socket: %v", err) - } - }() - for { select { case <-ctx.Done(): - return ErrStopped + return ctx.Err() default: buf := make([]byte, 2048) n, err := unix.Read(fd, buf) @@ -76,11 +68,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca switch msg.Type { case unix.RTM_ADD: log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) - go callback() + return nil case unix.RTM_DELETE: if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - go callback() + return nil } } } diff --git a/client/internal/networkmonitor/monitor_linux.go b/client/internal/networkmonitor/check_change_linux.go similarity index 93% rename from client/internal/networkmonitor/monitor_linux.go rename to client/internal/networkmonitor/check_change_linux.go index 035be1f09..efd8b5884 100644 --- a/client/internal/networkmonitor/monitor_linux.go +++ b/client/internal/networkmonitor/check_change_linux.go @@ -14,7 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { if nexthopv4.Intf == nil && nexthopv6.Intf == nil { return errors.New("no interfaces available") } @@ -31,8 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca for { select { case <-ctx.Done(): - return ErrStopped - + return ctx.Err() // handle route changes case route := <-routeChan: // default route and main table @@ -43,12 +42,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca // triggered on added/replaced routes case syscall.RTM_NEWROUTE: log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) - go callback() return nil case syscall.RTM_DELROUTE: if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) { log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) - go callback() return nil } } diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/check_change_windows.go similarity index 89% rename from client/internal/networkmonitor/monitor_windows.go rename to client/internal/networkmonitor/check_change_windows.go index cd48c269d..582865738 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/check_change_windows.go @@ -10,7 +10,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { routeMonitor, err := systemops.NewRouteMonitor(ctx) if err != nil { return fmt.Errorf("failed to create route monitor: %w", err) @@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca for { select { case <-ctx.Done(): - return ErrStopped + return ctx.Err() case route := <-routeMonitor.RouteUpdates(): if route.Destination.Bits() != 0 { continue } - if routeChanged(route, nexthopv4, nexthopv6, callback) { - break + if routeChanged(route, nexthopv4, nexthopv6) { + return nil } } } } -func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool { +func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool { intf := "" if route.Interface != nil { intf = route.Interface.Name @@ -51,18 +51,15 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne case systemops.RouteModified: // TODO: get routing table to figure out if our route is affected for modified routes log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) - go callback() return true case systemops.RouteAdded: if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) - go callback() return true } case systemops.RouteDeleted: if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) - go callback() return true } } diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index 5475455c6..5896b66b6 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -1,12 +1,27 @@ +//go:build !ios && !android + package networkmonitor import ( "context" "errors" + "fmt" + "net/netip" + "runtime/debug" "sync" + "time" + + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -var ErrStopped = errors.New("monitor has been stopped") +const ( + debounceTime = 2 * time.Second +) + +var checkChangeFn = checkChange // NetworkMonitor watches for changes in network configuration. type NetworkMonitor struct { @@ -19,3 +34,99 @@ type NetworkMonitor struct { func New() *NetworkMonitor { return &NetworkMonitor{} } + +// Listen begins monitoring network changes. When a change is detected, this function will return without error. +func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { + nw.mu.Lock() + if nw.cancel != nil { + nw.mu.Unlock() + return errors.New("network monitor already started") + } + + ctx, nw.cancel = context.WithCancel(ctx) + defer nw.cancel() + nw.wg.Add(1) + nw.mu.Unlock() + + defer nw.wg.Done() + + var nexthop4, nexthop6 systemops.Nexthop + + operation := func() error { + var errv4, errv6 error + nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified()) + nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified()) + + if errv4 != nil && errv6 != nil { + return errors.New("failed to get default next hops") + } + + if errv4 == nil { + log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name) + } + if errv6 == nil { + log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name) + } + + // continue if either route was found + return nil + } + + expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) + + if err := backoff.Retry(operation, expBackOff); err != nil { + return fmt.Errorf("failed to get default next hops: %w", err) + } + + // recover in case sys ops panic + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack()) + } + }() + + event := make(chan struct{}, 1) + go nw.checkChanges(ctx, event, nexthop4, nexthop6) + + // debounce changes + timer := time.NewTimer(0) + timer.Stop() + for { + select { + case <-event: + timer.Reset(debounceTime) + case <-timer.C: + return nil + case <-ctx.Done(): + timer.Stop() + return ctx.Err() + } + } +} + +// Stop stops the network monitor. +func (nw *NetworkMonitor) Stop() { + nw.mu.Lock() + defer nw.mu.Unlock() + + if nw.cancel == nil { + return + } + + nw.cancel() + nw.wg.Wait() +} + +func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) { + for { + if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil { + close(event) + return + } + // prevent blocking + select { + case event <- struct{}{}: + default: + } + } +} diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go deleted file mode 100644 index 19648edba..000000000 --- a/client/internal/networkmonitor/monitor_generic.go +++ /dev/null @@ -1,82 +0,0 @@ -//go:build !ios && !android - -package networkmonitor - -import ( - "context" - "errors" - "fmt" - "net/netip" - "runtime/debug" - - "github.com/cenkalti/backoff/v4" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/internal/routemanager/systemops" -) - -// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns. -func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) { - if ctx.Err() != nil { - return ctx.Err() - } - - nw.mu.Lock() - ctx, nw.cancel = context.WithCancel(ctx) - nw.mu.Unlock() - - nw.wg.Add(1) - defer nw.wg.Done() - - var nexthop4, nexthop6 systemops.Nexthop - - operation := func() error { - var errv4, errv6 error - nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified()) - nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified()) - - if errv4 != nil && errv6 != nil { - return errors.New("failed to get default next hops") - } - - if errv4 == nil { - log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name) - } - if errv6 == nil { - log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name) - } - - // continue if either route was found - return nil - } - - expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) - - if err := backoff.Retry(operation, expBackOff); err != nil { - return fmt.Errorf("failed to get default next hops: %w", err) - } - - // recover in case sys ops panic - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack()) - } - }() - - if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil { - return fmt.Errorf("check change: %w", err) - } - - return nil -} - -// Stop stops the network monitor. -func (nw *NetworkMonitor) Stop() { - nw.mu.Lock() - defer nw.mu.Unlock() - - if nw.cancel != nil { - nw.cancel() - nw.wg.Wait() - } -} diff --git a/client/internal/networkmonitor/monitor_mobile.go b/client/internal/networkmonitor/monitor_mobile.go index c81fad16c..861dbbe3c 100644 --- a/client/internal/networkmonitor/monitor_mobile.go +++ b/client/internal/networkmonitor/monitor_mobile.go @@ -2,10 +2,21 @@ package networkmonitor -import "context" +import ( + "context" + "fmt" +) -func (nw *NetworkMonitor) Start(context.Context, func()) error { - return nil +type NetworkMonitor struct { +} + +// New creates a new network monitor. +func New() *NetworkMonitor { + return &NetworkMonitor{} +} + +func (nw *NetworkMonitor) Listen(_ context.Context) error { + return fmt.Errorf("network monitor not supported on mobile platforms") } func (nw *NetworkMonitor) Stop() { diff --git a/client/internal/networkmonitor/monitor_test.go b/client/internal/networkmonitor/monitor_test.go new file mode 100644 index 000000000..164686689 --- /dev/null +++ b/client/internal/networkmonitor/monitor_test.go @@ -0,0 +1,99 @@ +package networkmonitor + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +type MocMultiEvent struct { + counter int +} + +func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + if m.counter == 0 { + <-ctx.Done() + return ctx.Err() + } + + time.Sleep(1 * time.Second) + m.counter-- + return nil +} + +func TestNetworkMonitor_Close(t *testing.T) { + checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + <-ctx.Done() + return ctx.Err() + } + nw := New() + + var resErr error + done := make(chan struct{}) + go func() { + resErr = nw.Listen(context.Background()) + close(done) + }() + + time.Sleep(1 * time.Second) // wait for the goroutine to start + nw.Stop() + + <-done + if !errors.Is(resErr, context.Canceled) { + t.Errorf("unexpected error: %v", resErr) + } +} + +func TestNetworkMonitor_Event(t *testing.T) { + checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + timeout, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timeout.Done(): + return nil + } + } + nw := New() + defer nw.Stop() + + var resErr error + done := make(chan struct{}) + go func() { + resErr = nw.Listen(context.Background()) + close(done) + }() + + <-done + if !errors.Is(resErr, nil) { + t.Errorf("unexpected error: %v", nil) + } +} + +func TestNetworkMonitor_MultiEvent(t *testing.T) { + eventsRepeated := 3 + me := &MocMultiEvent{counter: eventsRepeated} + checkChangeFn = me.checkChange + + nw := New() + defer nw.Stop() + + done := make(chan struct{}) + started := time.Now() + go func() { + if resErr := nw.Listen(context.Background()); resErr != nil { + t.Errorf("unexpected error: %v", resErr) + } + close(done) + }() + + <-done + expectedResponseTime := time.Duration(eventsRepeated)*time.Second + debounceTime + if time.Since(started) < expectedResponseTime { + t.Errorf("unexpected duration: %v", time.Since(started)) + } +} diff --git a/client/internal/peer/iface.go b/client/internal/peer/iface.go index c7b6de9ea..32ac5c7db 100644 --- a/client/internal/peer/iface.go +++ b/client/internal/peer/iface.go @@ -8,6 +8,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" ) @@ -16,4 +17,5 @@ type WGIface interface { RemovePeer(peerKey string) error GetStats(peerKey string) (configurer.WGStats, error) GetProxy() wgproxy.Proxy + Address() wgaddr.Address } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 7dd84a98e..5ceb3f453 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -358,6 +358,12 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive } func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { + addr, err := netip.ParseAddr(candidate.Address()) + if err != nil { + log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err) + return false + } + var routePrefixes []netip.Prefix for _, routes := range clientRoutes { if len(routes) > 0 && routes[0] != nil { @@ -365,14 +371,8 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool } } - addr, err := netip.ParseAddr(candidate.Address()) - if err != nil { - log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err) - return false - } - for _, prefix := range routePrefixes { - // default route is + // default route is handled by route exclusion / ip rules if prefix.Bits() == 0 { continue } diff --git a/client/internal/routemanager/iface/iface_common.go b/client/internal/routemanager/iface/iface_common.go index 8b2dc9714..9e1f8058a 100644 --- a/client/internal/routemanager/iface/iface_common.go +++ b/client/internal/routemanager/iface/iface_common.go @@ -3,9 +3,9 @@ package iface import ( "net" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" ) type wgIfaceBase interface { @@ -13,7 +13,7 @@ type wgIfaceBase interface { RemoveAllowedIP(peerKey string, allowedIP string) error Name() string - Address() iface.WGAddress + Address() wgaddr.Address ToInterface() *net.Interface IsUserspaceBind() bool GetFilter() device.PacketFilter diff --git a/client/server/panic_windows.go b/client/server/panic_windows.go index 1d4ba4b75..c5e73be7c 100644 --- a/client/server/panic_windows.go +++ b/client/server/panic_windows.go @@ -3,7 +3,7 @@ package server import ( "fmt" "os" - "path/filepath" + "path" "syscall" log "github.com/sirupsen/logrus" @@ -12,7 +12,6 @@ import ( ) const ( - windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG" // STD_ERROR_HANDLE ((DWORD)-12) = 4294967284 stdErrorHandle = ^uintptr(11) ) @@ -25,13 +24,10 @@ var ( ) func handlePanicLog() error { - logPath := os.Getenv(windowsPanicLogEnvVar) - if logPath == "" { - return nil - } + // TODO: move this to a central location + logDir := path.Join(os.Getenv("PROGRAMDATA"), "Netbird") + logPath := path.Join(logDir, "netbird.err") - // Ensure the directory exists - logDir := filepath.Dir(logPath) if err := os.MkdirAll(logDir, 0750); err != nil { return fmt.Errorf("create panic log directory: %w", err) } @@ -39,13 +35,11 @@ func handlePanicLog() error { return fmt.Errorf("enforce permission on panic log file: %w", err) } - // Open log file with append mode f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) if err != nil { return fmt.Errorf("open panic log file: %w", err) } - // Redirect stderr to the file if err = redirectStderr(f); err != nil { if closeErr := f.Close(); closeErr != nil { log.Warnf("failed to close file after redirect error: %v", closeErr) @@ -59,7 +53,6 @@ func handlePanicLog() error { // redirectStderr redirects stderr to the provided file func redirectStderr(f *os.File) error { - // Get the current process's stderr handle if err := setStdHandle(f); err != nil { return fmt.Errorf("failed to set stderr handle: %w", err) } diff --git a/client/server/server.go b/client/server/server.go index 8907f541f..2d8f759cd 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -160,7 +160,7 @@ func (s *Server) Start() error { // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, - runningChan chan error, + runningChan chan struct{}, ) { backOff := getConnectWithBackoff(ctx) retryStarted := false @@ -628,20 +628,21 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) - runningChan := make(chan error) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) + timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) + defer cancel() + runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) for { select { - case err := <-runningChan: - if err != nil { - log.Debugf("waiting for engine to become ready failed: %s", err) - } else { - return &proto.UpResponse{}, nil - } + case <-runningChan: + return &proto.UpResponse{}, nil case <-callerCtx.Done(): log.Debug("context done, stopping the wait for engine to become ready") return nil, callerCtx.Err() + case <-timeoutCtx.Done(): + log.Debug("up is timed out, stopping the wait for engine to become ready") + return nil, timeoutCtx.Err() } } } diff --git a/infrastructure_files/observability/grafana/dashboards/signal.json b/infrastructure_files/observability/grafana/dashboards/signal.json index 5e36f6ce6..0dc1b7aa6 100644 --- a/infrastructure_files/observability/grafana/dashboards/signal.json +++ b/infrastructure_files/observability/grafana/dashboards/signal.json @@ -757,7 +757,7 @@ }, "id": 18, "panels": [], - "title": "Core metrics / registerations", + "title": "Core metrics / registrations", "type": "row" }, { @@ -1874,4 +1874,4 @@ "uid": "cebyq0fs0m-v001", "version": 15, "weekStart": "" - } \ No newline at end of file + }