Compare commits

...

7 Commits

Author SHA1 Message Date
Hakan Sariman
efe4ed8c1b Refactor DNS server update to use structured hash type for configuration 2025-09-18 15:05:41 +07:00
Bethuel Mmbaga
3130cce72d [management] Add rule ID validation for policy updates (#4499) 2025-09-15 21:08:16 +03:00
Zoltan Papp
bd23ab925e [client] Fix ICE latency handling (#4501)
The GetSelectedCandidatePair() does not carry the latency information.
2025-09-15 15:08:53 +02:00
Zoltan Papp
0c6f671a7c Refactor healthcheck sender and receiver to use configurable options (#4433) 2025-09-12 09:31:03 +02:00
Bethuel Mmbaga
cf7f6c355f [misc] Remove default zitadel admin user in deployment script (#4482)
* Delete default zitadel-admin user during initialization

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-09-11 21:20:10 +02:00
Zoltan Papp
47e64d72db [client] Fix client status check (#4474)
The client status is not enough to protect the RPC calls from concurrency issues, because it is handled internally in the client in an asynchronous way.
2025-09-11 16:21:09 +02:00
Zoltan Papp
9e81e782e5 [client] Fix/v4 stun routing (#4430)
Deduplicate STUN package sending.
Originally, because every peer shared the same UDP address, the library could not distinguish which STUN message was associated with which candidate. As a result, the Pion library responded from all candidates for every STUN message.
2025-09-11 10:08:54 +02:00
40 changed files with 772 additions and 575 deletions

View File

@@ -217,7 +217,7 @@ jobs:
- arch: "386" - arch: "386"
raceFlag: "" raceFlag: ""
- arch: "amd64" - arch: "amd64"
raceFlag: "" raceFlag: "-race"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go

View File

@@ -15,6 +15,7 @@ import (
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -44,7 +45,7 @@ type ICEBind struct {
RecvChan chan RecvMessage RecvChan chan RecvMessage
transportNet transport.Net transportNet transport.Net
filterFn FilterFn filterFn udpmux.FilterFn
endpoints map[netip.Addr]net.Conn endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a // every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
@@ -54,13 +55,13 @@ type ICEBind struct {
closed bool closed bool
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
address wgaddr.Address address wgaddr.Address
mtu uint16 mtu uint16
activityRecorder *ActivityRecorder activityRecorder *ActivityRecorder
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
} }
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind // GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) { func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
if s.udpMux == nil { if s.udpMux == nil {
@@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock() s.muUDPMux.Lock()
defer s.muUDPMux.Unlock() defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = udpmux.NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn), UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,

View File

@@ -1,7 +0,0 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -7,14 +7,14 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16 MTU() uint16

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -29,7 +30,7 @@ type WGTunDevice struct {
name string name string
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -26,7 +27,7 @@ type TunDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -28,7 +29,7 @@ type TunDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -12,8 +12,8 @@ import (
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
link *wgLink link *wgLink
udpMuxConn net.PacketConn udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
filterFn bind.FilterFn filterFn udpmux.FilterFn
} }
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil return configurer, nil
} }
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.udpMux != nil { if t.udpMux != nil {
return t.udpMux, nil return t.udpMux, nil
} }
@@ -106,14 +106,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
udpConn = nbnet.WrapPacketConn(rawSock) udpConn = nbnet.WrapPacketConn(rawSock)
} }
bindParams := bind.UniversalUDPMuxParams{ bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: udpConn, UDPConn: udpConn,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address, WGAddress: t.address,
MTU: t.mtu, MTU: t.mtu,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock t.udpMuxConn = rawSock
t.udpMux = mux t.udpMux = mux

View File

@@ -10,6 +10,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -26,7 +27,7 @@ type TunNetstackDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun nsTun *nbnetstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
net *netstack.Net net *netstack.Net
@@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -25,7 +26,7 @@ type USPDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -29,7 +30,7 @@ type TunDevice struct {
device *device.Device device *device.Device
nativeTunDevice *tun.NativeTun nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
} }
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -5,14 +5,14 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address
MTU() uint16 MTU() uint16

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
MTU uint16 MTU uint16
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn udpmux.FilterFn
DisableDNS bool DisableDNS bool
} }
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface // Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before) // The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) { func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()

View File

@@ -1,4 +1,4 @@
package bind package udpmux
/* /*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,11 +16,12 @@ import (
) )
type udpMuxedConnParams struct { type udpMuxedConnParams struct {
Mux *UDPMuxDefault Mux *SingleSocketUDPMux
AddrPool *sync.Pool AddrPool *sync.Pool
Key string Key string
LocalAddr net.Addr LocalAddr net.Addr
Logger logging.LeveledLogger Logger logging.LeveledLogger
CandidateID string
} }
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag // udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
return err return err
} }
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool { func (c *udpMuxedConn) isClosed() bool {
select { select {
case <-c.closedChan: case <-c.closedChan:

View File

@@ -0,0 +1,64 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,4 +1,4 @@
package bind package udpmux
import ( import (
"fmt" "fmt"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192 const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface // SingleSocketUDPMux is an implementation of the interface
type UDPMuxDefault struct { type SingleSocketUDPMux struct {
params UDPMuxParams params Params
closedChan chan struct{} closedChan chan struct{}
closeOnce sync.Once closeOnce sync.Once
@@ -32,6 +32,9 @@ type UDPMuxDefault struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType // connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn addressMap map[string][]*udpMuxedConn
@@ -46,8 +49,8 @@ type UDPMuxDefault struct {
const maxAddrSize = 512 const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux. // Params are parameters for UDPMux.
type UDPMuxParams struct { type Params struct {
Logger logging.LeveledLogger Logger logging.LeveledLogger
UDPConn net.PacketConn UDPConn net.PacketConn
@@ -147,17 +150,18 @@ func isZeros(ip net.IP) bool {
return true return true
} }
// NewUDPMuxDefault creates an implementation of UDPMux // NewSingleSocketUDPMux creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
if params.Logger == nil { if params.Logger == nil {
params.Logger = getLogger() params.Logger = getLogger()
} }
mux := &UDPMuxDefault{ mux := &SingleSocketUDPMux{
addressMap: map[string][]*udpMuxedConn{}, addressMap: map[string][]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn), connsIPv6: make(map[string]*udpMuxedConn),
candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1), closedChan: make(chan struct{}, 1),
pool: &sync.Pool{ pool: &sync.Pool{
New: func() interface{} { New: func() interface{} {
@@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return mux return mux
} }
func (m *UDPMuxDefault) updateLocalAddresses() { func (m *SingleSocketUDPMux) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok { if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr()) m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() { } else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but // For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection // it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux. // with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType var networks []ice.NetworkType
switch { switch {
@@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
m.mu.Unlock() m.mu.Unlock()
} }
// LocalAddr returns the listening address of this UDPMuxDefault // LocalAddr returns the listening address of this SingleSocketUDPMux
func (m *UDPMuxDefault) LocalAddr() net.Addr { func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr() return m.params.UDPConn.LocalAddr()
} }
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
m.updateLocalAddresses() m.updateLocalAddresses()
m.mu.Lock() m.mu.Lock()
@@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address // GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
// don't check addr for mux using unspecified address // don't check addr for mux using unspecified address
m.mu.Lock() m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified) lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return conn, nil return conn, nil
} }
c := m.createMuxedConn(ufrag) c := m.createMuxedConn(ufrag, candidateID)
go func() { go func() {
<-c.CloseChannel() <-c.CloseChannel()
m.RemoveConnByUfrag(ufrag) m.RemoveConnByUfrag(ufrag)
}() }()
m.candidateConnMap[candidateID] = c
if isIPv6 { if isIPv6 {
m.connsIPv6[ufrag] = c m.connsIPv6[ufrag] = c
} else { } else {
@@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
} }
// RemoveConnByUfrag stops and removes the muxed packet connection // RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2) removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock // Keep lock section small to avoid deadlock with conn lock
@@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok { if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag) delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c) removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
} }
if c, ok := m.connsIPv6[ufrag]; ok { if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag) delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c) removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
} }
m.mu.Unlock() m.mu.Unlock()
@@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
} }
// IsClosed returns true if the mux had been closed // IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool { func (m *SingleSocketUDPMux) IsClosed() bool {
select { select {
case <-m.closedChan: case <-m.closedChan:
return true return true
@@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
} }
// Close the mux, no further connections could be created // Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error { func (m *SingleSocketUDPMux) Close() error {
var err error var err error
m.closeOnce.Do(func() { m.closeOnce.Do(func() {
m.mu.Lock() m.mu.Lock()
@@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error {
return err return err
} }
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) { func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr) return m.params.UDPConn.WriteTo(buf, rAddr)
} }
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) { func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() { if m.IsClosed() {
return return
} }
@@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
} }
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn { func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{ c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m, Mux: m,
Key: key, Key: key,
AddrPool: m.pool, AddrPool: m.pool,
LocalAddr: m.LocalAddr(), LocalAddr: m.LocalAddr(),
Logger: m.params.Logger, Logger: m.params.Logger,
CandidateID: candidateID,
}) })
return c return c
} }
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library // HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error { func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr) remoteAddr, ok := addr.(*net.UDPAddr)
if !ok { if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr") return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
} }
// If we have already seen this address dispatch to the appropriate destination // Try to route to specific candidate connection first
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one if conn := m.findCandidateConnection(msg); conn != nil {
// muxed connection - one for the SRFLX candidate and the other one for the HOST one. return conn.writePacket(msg.Raw, remoteAddr)
// We will then forward STUN packets to each of these connections. }
m.addressMapMu.RLock()
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn var destinationConnList []*udpMuxedConn
// Add connections from address map
m.addressMapMu.RLock()
if storedConns, ok := m.addressMap[addr.String()]; ok { if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...) destinationConnList = append(destinationConnList, storedConns...)
} }
m.addressMapMu.RUnlock() m.addressMapMu.RUnlock()
var isIPv6 bool if conn, ok := m.findConnectionByUsername(msg, addr); ok {
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { // If we have already seen this address dispatch to the appropriate destination
isIPv6 = true // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
} }
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront. // Forward to all found connections
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
for _, conn := range destinationConnList { for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil { if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err) log.Errorf("could not write packet: %v", err)
} }
} }
return nil return nil
} }
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) { // findConnectionByUsername finds connection using username attribute from STUN message
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 { if isIPv6 {
val, ok = m.connsIPv6[ufrag] val, ok = m.connsIPv6[ufrag]
} else { } else {
@@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
return return
} }
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct { type bufferHolder struct {
buf []byte buf []byte
} }

View File

@@ -1,12 +1,12 @@
//go:build !ios //go:build !ios
package bind package udpmux
import ( import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet) // Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok { if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr) conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package bind package udpmux
/* /*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements. Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing. // It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct { type UniversalUDPMuxDefault struct {
*UDPMuxDefault *SingleSocketUDPMux
params UniversalUDPMuxParams params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents // since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress, address: params.WGAddress,
} }
udpMuxParams := UDPMuxParams{ udpMuxParams := Params{
Logger: params.Logger, Logger: params.Logger,
UDPConn: m.params.UDPConn, UDPConn: m.params.UDPConn,
Net: m.params.Net, Net: m.params.Net,
} }
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams) m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
return m return m
} }
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers // GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server. // and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) { func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr) return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
} }
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server. // HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
} }
return nil return nil
} }
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr) return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
} }
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server. // isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -280,15 +280,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err) return wrapErr(err)
} }
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil { if runningChan != nil {
select { close(runningChan)
case runningChan <- struct{}{}: runningChan = nil
default:
}
} }
<-engineCtx.Done() <-engineCtx.Done()

View File

@@ -391,7 +391,15 @@ func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) erro
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ hashType := struct {
ServiceEnable bool
NameServerGroups []*nbdns.NameServerGroup
}{
ServiceEnable: update.ServiceEnable,
NameServerGroups: update.NameServerGroups,
}
hash, err := hashstructure.Hash(hashType, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true, ZeroNil: true,
IgnoreZeroValue: true, IgnoreZeroValue: true,
SlicesAsSets: true, SlicesAsSets: true,

View File

@@ -29,9 +29,9 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
@@ -166,7 +166,7 @@ type Engine struct {
wgInterface WGIface wgInterface WGIface
udpMux *bind.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
@@ -461,7 +461,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
StunTurn: &e.stunTurn, StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList, InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault, UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux, UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
} }
@@ -1326,7 +1326,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
StunTurn: &e.stunTurn, StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList, InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault, UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux, UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(), NATExternalIPs: e.parseNATExternalIPMappings(),
}, },

View File

@@ -26,10 +26,11 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
@@ -84,7 +85,7 @@ type MockWGIface struct {
NameFunc func() string NameFunc func() string
AddressFunc func() wgaddr.Address AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error RemovePeerFunc func(peerKey string) error
@@ -134,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc() return m.ToInterfaceFunc()
} }
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return m.UpFunc() return m.UpFunc()
} }
@@ -413,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280}) engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{}) engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface) engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)

View File

@@ -9,9 +9,9 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime" "github.com/netbirdio/netbird/monotime"
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
Name() string Name() string
Address() wgaddr.Address Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error

View File

@@ -9,11 +9,10 @@ import (
"time" "time"
"github.com/pion/ice/v4" "github.com/pion/ice/v4"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype" "github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -55,10 +54,6 @@ type WorkerICE struct {
sessionID ICESessionID sessionID ICESessionID
muxAgent sync.Mutex muxAgent sync.Mutex
StunTurn []*stun.URI
sentExtraSrflx bool
localUfrag string localUfrag string
localPwd string localPwd string
@@ -139,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.muxAgent.Unlock() w.muxAgent.Unlock()
return return
} }
w.sentExtraSrflx = false
w.agent = agent w.agent = agent
w.agentDialerCancel = dialerCancel w.agentDialerCancel = dialerCancel
w.agentConnecting = true w.agentConnecting = true
@@ -166,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
w.log.Errorf("error while handling remote candidate") w.log.Errorf("error while handling remote candidate")
return return
} }
if shouldAddExtraCandidate(candidate) {
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
} }
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
@@ -209,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
return nil, err return nil, err
} }
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil { if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
w.onICESelectedCandidatePair(agent, c1, c2)
}); err != nil {
return nil, err return nil, err
} }
@@ -327,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
return return
} }
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
if !ok { if !ok {
w.log.Warn("invalid udp mux conversion") w.log.Warn("invalid udp mux conversion")
return return
@@ -354,48 +365,19 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
} }
}() }()
if !w.shouldSendExtraSrflxCandidate(candidate) {
return
} }
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
w.sentExtraSrflx = true
go func() {
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
}
}()
}
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key) w.config.Key)
w.muxAgent.Lock() pairStat, ok := agent.GetSelectedCandidatePairStats()
if !ok {
pair, err := w.agent.GetSelectedCandidatePair() w.log.Warnf("failed to get selected candidate pair stats")
if err != nil {
w.log.Warnf("failed to get selected candidate pair: %s", err)
w.muxAgent.Unlock()
return return
} }
if pair == nil {
w.log.Warnf("selected candidate pair is nil, cannot proceed")
w.muxAgent.Unlock()
return
}
w.muxAgent.Unlock()
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second)) duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil { if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err) w.log.Debugf("failed to update latency for peer: %s", err)
return return
@@ -424,22 +406,31 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
} }
} }
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key if isController(w.config) {
if isControlling { return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else { } else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} }
} }
func shouldAddExtraCandidate(candidate ice.Candidate) bool {
if candidate.Type() != ice.CandidateTypeServerReflexive {
return false
}
if candidate.Port() == candidate.RelatedAddress().Port {
return false
}
// in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
// in newer version we generate locally the extra candidate
if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
return false
}
return true
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress() relatedAdd := candidate.RelatedAddress()
ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
@@ -455,6 +446,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
} }
for _, e := range candidate.Extensions() { for _, e := range candidate.Extensions() {
// overwrite the original candidate ID with the new one to avoid candidate duplication
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = candidate.ID()
}
if err := ec.AddExtension(e); err != nil { if err := ec.AddExtension(e); err != nil {
return nil, err return nil, err
} }

View File

@@ -65,6 +65,8 @@ type Server struct {
mutex sync.Mutex mutex sync.Mutex
config *profilemanager.Config config *profilemanager.Config
proto.UnimplementedDaemonServiceServer proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
clientRunningChan chan struct{}
connectClient *internal.ConnectClient connectClient *internal.ConnectClient
@@ -103,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
func (s *Server) Start() error { func (s *Server) Start() error {
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil { if err := handlePanicLog(); err != nil {
@@ -172,8 +175,12 @@ func (s *Server) Start() error {
return nil return nil
} }
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) if s.clientRunning {
return nil
}
s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1)
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan)
return nil return nil
} }
@@ -204,12 +211,22 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost. // mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) {
runningChan chan struct{}, defer func() {
) { s.mutex.Lock()
backOff := getConnectWithBackoff(ctx) s.clientRunning = false
retryStarted := false s.mutex.Unlock()
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
return
}
backOff := getConnectWithBackoff(ctx)
go func() { go func() {
t := time.NewTicker(24 * time.Hour) t := time.NewTicker(24 * time.Hour)
for { for {
@@ -218,89 +235,32 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage
t.Stop() t.Stop()
return return
case <-t.C: case <-t.C:
if retryStarted {
mgmtState := statusRecorder.GetManagementState() mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState() signalState := statusRecorder.GetSignalState()
if mgmtState.Connected && signalState.Connected { if mgmtState.Connected && signalState.Connected {
log.Tracef("resetting status") log.Tracef("resetting status")
retryStarted = false backOff.Reset()
} else { } else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected) log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
} }
} }
} }
}
}() }()
runOperation := func() error { runOperation := func() error {
log.Tracef("running client connection") err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
err := s.connectClient.Run(runningChan)
if err != nil { if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err) log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
} }
if config.DisableAutoConnect { log.Tracef("client connection exited gracefully, do not need to retry")
return backoff.Permanent(err) return nil
} }
if !retryStarted { if err := backoff.Retry(runOperation, backOff); err != nil {
retryStarted = true log.Errorf("operation failed: %v", err)
backOff.Reset()
} }
log.Tracef("client connection exited")
return fmt.Errorf("client connection exited")
}
err := backoff.Retry(runOperation, backOff)
if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled {
log.Errorf("received an error when trying to connect: %v", err)
} else {
log.Tracef("retry canceled")
}
}
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
} }
// loginAttempt attempts to login using the provided information. it returns a status in case something fails // loginAttempt attempts to login using the provided information. it returns a status in case something fails
@@ -716,11 +676,14 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second) timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel() defer cancel()
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal if !s.clientRunning {
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) s.clientRunning = true
s.clientRunningChan = make(chan struct{}, 1)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan)
}
for { for {
select { select {
case <-runningChan: case <-s.clientRunningChan:
s.isSessionActive.Store(true) s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil return &proto.UpResponse{}, nil
case <-callerCtx.Done(): case <-callerCtx.Done():
@@ -1127,6 +1090,134 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
}, nil }, nil
} }
// AddProfile adds a new profile to the daemon.
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" || msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
}
profiles, err := s.profileManager.ListProfiles(msg.Username)
if err != nil {
log.Errorf("failed to list profiles: %v", err)
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
response := &proto.ListProfilesResponse{
Profiles: make([]*proto.Profile, len(profiles)),
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Name: profile.Name,
IsActive: profile.IsActive,
}
}
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
activeProfile, err := s.profileManager.GetActiveProfileState()
if err != nil {
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err
}
return nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}
return false
}
func (s *Server) onSessionExpire() { func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp() isUIActive := internal.CheckUIApp()
@@ -1138,6 +1229,45 @@ func (s *Server) onSessionExpire() {
} }
} }
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{ pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{}, ManagementState: &proto.ManagementState{},
@@ -1252,121 +1382,3 @@ func sendTerminalNotification() error {
return wallCmd.Wait() return wallCmd.Wait()
} }
// AddProfile adds a new profile to the daemon.
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" || msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
}
profiles, err := s.profileManager.ListProfiles(msg.Username)
if err != nil {
log.Errorf("failed to list profiles: %v", err)
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
response := &proto.ListProfilesResponse{
Profiles: make([]*proto.Profile, len(profiles)),
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Name: profile.Name,
IsActive: profile.IsActive,
}
}
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
activeProfile, err := s.profileManager.GetActiveProfileState()
if err != nil {
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}
return false
}

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"

2
go.mod
View File

@@ -261,6 +261,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944

4
go.sum
View File

@@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE
github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM= github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=

View File

@@ -328,6 +328,45 @@ delete_auto_service_user() {
echo "$PARSED_RESPONSE" echo "$PARSED_RESPONSE"
} }
delete_default_zitadel_admin() {
INSTANCE_URL=$1
PAT=$2
# Search for the default zitadel-admin user
RESPONSE=$(
curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \
-H "Authorization: Bearer $PAT" \
-H "Content-Type: application/json" \
-d '{
"queries": [
{
"userNameQuery": {
"userName": "zitadel-admin@",
"method": "TEXT_QUERY_METHOD_STARTS_WITH"
}
}
]
}'
)
DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty')
if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then
echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID"
RESPONSE=$(
curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \
-H "Authorization: Bearer $PAT" \
-H "Content-Type: application/json" \
)
PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"')
handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE"
else
echo "Default zitadel-admin user not found: $RESPONSE"
fi
}
init_zitadel() { init_zitadel() {
echo -e "\nInitializing Zitadel with NetBird's applications\n" echo -e "\nInitializing Zitadel with NetBird's applications\n"
INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN" INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
@@ -346,6 +385,9 @@ init_zitadel() {
echo -n "Waiting for Zitadel to become ready " echo -n "Waiting for Zitadel to become ready "
wait_api "$INSTANCE_URL" "$PAT" wait_api "$INSTANCE_URL" "$PAT"
echo "Deleting default zitadel-admin user..."
delete_default_zitadel_admin "$INSTANCE_URL" "$PAT"
# create the zitadel project # create the zitadel project
echo "Creating new zitadel project" echo "Creating new zitadel project"
PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT") PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT")

View File

@@ -167,10 +167,22 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules. // validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" { if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return err return err
} }
// TODO: Refactor to support multiple rules per policy
existingRuleIDs := make(map[string]bool)
for _, rule := range existingPolicy.Rules {
existingRuleIDs[rule.ID] = true
}
for _, rule := range policy.Rules {
if rule.ID != "" && !existingRuleIDs[rule.ID] {
return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID)
}
}
} else { } else {
policy.ID = xid.New().String() policy.ID = xid.New().String()
policy.AccountID = accountID policy.AccountID = accountID

View File

@@ -81,6 +81,7 @@ func NewManager(ctx context.Context, serverURLs []string, peerID string, mtu uin
TokenStore: tokenStore, TokenStore: tokenStore,
PeerID: peerID, PeerID: peerID,
MTU: mtu, MTU: mtu,
ConnectionTimeout: defaultConnectionTimeout,
}, },
relayClients: make(map[string]*RelayTrack), relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List), onDisconnectedListeners: make(map[string]*list.List),

View File

@@ -14,10 +14,7 @@ import (
const ( const (
maxConcurrentServers = 7 maxConcurrentServers = 7
) defaultConnectionTimeout = 30 * time.Second
var (
connectionTimeout = 30 * time.Second
) )
type connResult struct { type connResult struct {
@@ -31,10 +28,11 @@ type ServerPicker struct {
ServerURLs atomic.Value ServerURLs atomic.Value
PeerID string PeerID string
MTU uint16 MTU uint16
ConnectionTimeout time.Duration
} }
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) ctx, cancel := context.WithTimeout(parentCtx, sp.ConnectionTimeout)
defer cancel() defer cancel()
totalServers := len(sp.ServerURLs.Load().([]string)) totalServers := len(sp.ServerURLs.Load().([]string))

View File

@@ -8,15 +8,15 @@ import (
) )
func TestServerPicker_UnavailableServers(t *testing.T) { func TestServerPicker_UnavailableServers(t *testing.T) {
connectionTimeout = 5 * time.Second timeout := 5 * time.Second
sp := ServerPicker{ sp := ServerPicker{
TokenStore: nil, TokenStore: nil,
PeerID: "test", PeerID: "test",
ConnectionTimeout: timeout,
} }
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) ctx, cancel := context.WithTimeout(context.Background(), timeout+1)
defer cancel() defer cancel()
go func() { go func() {

View File

@@ -0,0 +1,24 @@
package healthcheck
import (
"os"
"strconv"
log "github.com/sirupsen/logrus"
)
const (
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
)
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -0,0 +1,36 @@
package healthcheck
import (
"os"
"testing"
)
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}

View File

@@ -7,10 +7,15 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var ( const (
heartbeatTimeout = healthCheckInterval + 10*time.Second defaultHeartbeatTimeout = defaultHealthCheckInterval + 10*time.Second
) )
type ReceiverOptions struct {
HeartbeatTimeout time.Duration
AttemptThreshold int
}
// Receiver is a healthcheck receiver // Receiver is a healthcheck receiver
// It will listen for heartbeat and check if the heartbeat is not received in a certain time // It will listen for heartbeat and check if the heartbeat is not received in a certain time
// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work
@@ -27,6 +32,23 @@ type Receiver struct {
// NewReceiver creates a new healthcheck receiver and start the timer in the background // NewReceiver creates a new healthcheck receiver and start the timer in the background
func NewReceiver(log *log.Entry) *Receiver { func NewReceiver(log *log.Entry) *Receiver {
opts := ReceiverOptions{
HeartbeatTimeout: defaultHeartbeatTimeout,
AttemptThreshold: getAttemptThresholdFromEnv(),
}
return NewReceiverWithOpts(log, opts)
}
func NewReceiverWithOpts(log *log.Entry, opts ReceiverOptions) *Receiver {
heartbeatTimeout := opts.HeartbeatTimeout
if heartbeatTimeout <= 0 {
heartbeatTimeout = defaultHeartbeatTimeout
}
attemptThreshold := opts.AttemptThreshold
if attemptThreshold <= 0 {
attemptThreshold = defaultAttemptThreshold
}
ctx, ctxCancel := context.WithCancel(context.Background()) ctx, ctxCancel := context.WithCancel(context.Background())
r := &Receiver{ r := &Receiver{
@@ -35,10 +57,10 @@ func NewReceiver(log *log.Entry) *Receiver {
ctx: ctx, ctx: ctx,
ctxCancel: ctxCancel, ctxCancel: ctxCancel,
heartbeat: make(chan struct{}, 1), heartbeat: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(), attemptThreshold: attemptThreshold,
} }
go r.waitForHealthcheck() go r.waitForHealthcheck(heartbeatTimeout)
return r return r
} }
@@ -55,7 +77,7 @@ func (r *Receiver) Stop() {
r.ctxCancel() r.ctxCancel()
} }
func (r *Receiver) waitForHealthcheck() { func (r *Receiver) waitForHealthcheck(heartbeatTimeout time.Duration) {
ticker := time.NewTicker(heartbeatTimeout) ticker := time.NewTicker(heartbeatTimeout)
defer ticker.Stop() defer ticker.Stop()
defer r.ctxCancel() defer r.ctxCancel()

View File

@@ -2,31 +2,18 @@ package healthcheck
import ( import (
"context" "context"
"fmt"
"os"
"sync"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Mutex to protect global variable access in tests
var testMutex sync.Mutex
func TestNewReceiver(t *testing.T) { func TestNewReceiver(t *testing.T) {
testMutex.Lock()
originalTimeout := heartbeatTimeout
heartbeatTimeout = 5 * time.Second
testMutex.Unlock()
defer func() { opts := ReceiverOptions{
testMutex.Lock() HeartbeatTimeout: 5 * time.Second,
heartbeatTimeout = originalTimeout }
testMutex.Unlock() r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop() defer r.Stop()
select { select {
@@ -38,18 +25,10 @@ func TestNewReceiver(t *testing.T) {
} }
func TestNewReceiverNotReceive(t *testing.T) { func TestNewReceiverNotReceive(t *testing.T) {
testMutex.Lock() opts := ReceiverOptions{
originalTimeout := heartbeatTimeout HeartbeatTimeout: 1 * time.Second,
heartbeatTimeout = 1 * time.Second }
testMutex.Unlock() r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop() defer r.Stop()
select { select {
@@ -61,18 +40,10 @@ func TestNewReceiverNotReceive(t *testing.T) {
} }
func TestNewReceiverAck(t *testing.T) { func TestNewReceiverAck(t *testing.T) {
testMutex.Lock() opts := ReceiverOptions{
originalTimeout := heartbeatTimeout HeartbeatTimeout: 2 * time.Second,
heartbeatTimeout = 2 * time.Second }
testMutex.Unlock() r := NewReceiverWithOpts(log.WithContext(context.Background()), opts)
defer func() {
testMutex.Lock()
heartbeatTimeout = originalTimeout
testMutex.Unlock()
}()
r := NewReceiver(log.WithContext(context.Background()))
defer r.Stop() defer r.Stop()
r.Heartbeat() r.Heartbeat()
@@ -97,30 +68,19 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases { for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
testMutex.Lock() healthCheckInterval := 1 * time.Second
originalInterval := healthCheckInterval
originalTimeout := heartbeatTimeout
healthCheckInterval = 1 * time.Second
heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
testMutex.Unlock()
defer func() { opts := ReceiverOptions{
testMutex.Lock() HeartbeatTimeout: healthCheckInterval + 500*time.Millisecond,
healthCheckInterval = originalInterval AttemptThreshold: tc.threshold,
heartbeatTimeout = originalTimeout }
testMutex.Unlock()
}()
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
receiver := NewReceiver(log.WithField("test_name", tc.name)) receiver := NewReceiverWithOpts(log.WithField("test_name", tc.name), opts)
testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval testTimeout := opts.HeartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
if tc.resetCounterOnce { if tc.resetCounterOnce {
receiver.Heartbeat() receiver.Heartbeat()
t.Logf("reset counter once")
} }
select { select {
@@ -134,7 +94,6 @@ func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
} }
t.Fatalf("should have timed out before %s", testTimeout) t.Fatalf("should have timed out before %s", testTimeout)
} }
}) })
} }
} }

View File

@@ -2,8 +2,6 @@ package healthcheck
import ( import (
"context" "context"
"os"
"strconv"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -11,43 +9,69 @@ import (
const ( const (
defaultAttemptThreshold = 1 defaultAttemptThreshold = 1
defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
defaultHealthCheckInterval = 25 * time.Second
defaultHealthCheckTimeout = 20 * time.Second
) )
var ( type SenderOptions struct {
healthCheckInterval = 25 * time.Second HealthCheckInterval time.Duration
healthCheckTimeout = 20 * time.Second HealthCheckTimeout time.Duration
) AttemptThreshold int
}
// Sender is a healthcheck sender // Sender is a healthcheck sender
// It will send healthcheck signal to the receiver // It will send healthcheck signal to the receiver
// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work
// It will also stop if the context is canceled // It will also stop if the context is canceled
type Sender struct { type Sender struct {
log *log.Entry
// HealthCheck is a channel to send health check signal to the peer // HealthCheck is a channel to send health check signal to the peer
HealthCheck chan struct{} HealthCheck chan struct{}
// Timeout is a channel to the health check signal is not received in a certain time // Timeout is a channel to the health check signal is not received in a certain time
Timeout chan struct{} Timeout chan struct{}
log *log.Entry
healthCheckInterval time.Duration
timeout time.Duration
ack chan struct{} ack chan struct{}
alive bool alive bool
attemptThreshold int attemptThreshold int
} }
// NewSender creates a new healthcheck sender func NewSenderWithOpts(log *log.Entry, opts SenderOptions) *Sender {
func NewSender(log *log.Entry) *Sender { if opts.HealthCheckInterval <= 0 {
opts.HealthCheckInterval = defaultHealthCheckInterval
}
if opts.HealthCheckTimeout <= 0 {
opts.HealthCheckTimeout = defaultHealthCheckTimeout
}
if opts.AttemptThreshold <= 0 {
opts.AttemptThreshold = defaultAttemptThreshold
}
hc := &Sender{ hc := &Sender{
log: log,
HealthCheck: make(chan struct{}, 1), HealthCheck: make(chan struct{}, 1),
Timeout: make(chan struct{}, 1), Timeout: make(chan struct{}, 1),
log: log,
healthCheckInterval: opts.HealthCheckInterval,
timeout: opts.HealthCheckInterval + opts.HealthCheckTimeout,
ack: make(chan struct{}, 1), ack: make(chan struct{}, 1),
attemptThreshold: getAttemptThresholdFromEnv(), attemptThreshold: opts.AttemptThreshold,
} }
return hc return hc
} }
// NewSender creates a new healthcheck sender
func NewSender(log *log.Entry) *Sender {
opts := SenderOptions{
HealthCheckInterval: defaultHealthCheckInterval,
HealthCheckTimeout: defaultHealthCheckTimeout,
AttemptThreshold: getAttemptThresholdFromEnv(),
}
return NewSenderWithOpts(log, opts)
}
// OnHCResponse sends an acknowledgment signal to the sender // OnHCResponse sends an acknowledgment signal to the sender
func (hc *Sender) OnHCResponse() { func (hc *Sender) OnHCResponse() {
select { select {
@@ -57,10 +81,10 @@ func (hc *Sender) OnHCResponse() {
} }
func (hc *Sender) StartHealthCheck(ctx context.Context) { func (hc *Sender) StartHealthCheck(ctx context.Context) {
ticker := time.NewTicker(healthCheckInterval) ticker := time.NewTicker(hc.healthCheckInterval)
defer ticker.Stop() defer ticker.Stop()
timeoutTicker := time.NewTicker(hc.getTimeoutTime()) timeoutTicker := time.NewTicker(hc.timeout)
defer timeoutTicker.Stop() defer timeoutTicker.Stop()
defer close(hc.HealthCheck) defer close(hc.HealthCheck)
@@ -92,19 +116,3 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) {
} }
} }
} }
func (hc *Sender) getTimeoutTime() time.Duration {
return healthCheckInterval + healthCheckTimeout
}
func getAttemptThresholdFromEnv() int {
if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
if err != nil {
log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
return defaultAttemptThreshold
}
return int(threshold)
}
return defaultAttemptThreshold
}

View File

@@ -2,26 +2,23 @@ package healthcheck
import ( import (
"context" "context"
"fmt"
"os"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
func TestMain(m *testing.M) { var (
// override the health check interval to speed up the test testOpts = SenderOptions{
healthCheckInterval = 2 * time.Second HealthCheckInterval: 2 * time.Second,
healthCheckTimeout = 100 * time.Millisecond HealthCheckTimeout: 100 * time.Millisecond,
code := m.Run()
os.Exit(code)
} }
)
func TestNewHealthPeriod(t *testing.T) { func TestNewHealthPeriod(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
hc := NewSender(log.WithContext(ctx)) hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
iterations := 0 iterations := 0
@@ -32,7 +29,7 @@ func TestNewHealthPeriod(t *testing.T) {
hc.OnHCResponse() hc.OnHCResponse()
case <-hc.Timeout: case <-hc.Timeout:
t.Fatalf("health check is timed out") t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond): case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received") t.Fatalf("health check not received")
} }
} }
@@ -41,19 +38,19 @@ func TestNewHealthPeriod(t *testing.T) {
func TestNewHealthFailed(t *testing.T) { func TestNewHealthFailed(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
hc := NewSender(log.WithContext(ctx)) hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
select { select {
case <-hc.Timeout: case <-hc.Timeout:
case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): case <-time.After(testOpts.HealthCheckInterval + testOpts.HealthCheckTimeout + 100*time.Millisecond):
t.Fatalf("health check is not timed out") t.Fatalf("health check is not timed out")
} }
} }
func TestNewHealthcheckStop(t *testing.T) { func TestNewHealthcheckStop(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
hc := NewSender(log.WithContext(ctx)) hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -78,7 +75,7 @@ func TestNewHealthcheckStop(t *testing.T) {
func TestTimeoutReset(t *testing.T) { func TestTimeoutReset(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
hc := NewSender(log.WithContext(ctx)) hc := NewSenderWithOpts(log.WithContext(ctx), testOpts)
go hc.StartHealthCheck(ctx) go hc.StartHealthCheck(ctx)
iterations := 0 iterations := 0
@@ -89,7 +86,7 @@ func TestTimeoutReset(t *testing.T) {
hc.OnHCResponse() hc.OnHCResponse()
case <-hc.Timeout: case <-hc.Timeout:
t.Fatalf("health check is timed out") t.Fatalf("health check is timed out")
case <-time.After(healthCheckInterval + 100*time.Millisecond): case <-time.After(testOpts.HealthCheckInterval + 100*time.Millisecond):
t.Fatalf("health check not received") t.Fatalf("health check not received")
} }
} }
@@ -118,19 +115,16 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
for _, tc := range testsCases { for _, tc := range testsCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
originalInterval := healthCheckInterval opts := SenderOptions{
originalTimeout := healthCheckTimeout HealthCheckInterval: 1 * time.Second,
healthCheckInterval = 1 * time.Second HealthCheckTimeout: 500 * time.Millisecond,
healthCheckTimeout = 500 * time.Millisecond AttemptThreshold: tc.threshold,
}
//nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
defer os.Unsetenv(defaultAttemptThresholdEnv)
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
sender := NewSender(log.WithField("test_name", tc.name)) sender := NewSenderWithOpts(log.WithField("test_name", tc.name), opts)
senderExit := make(chan struct{}) senderExit := make(chan struct{})
go func() { go func() {
sender.StartHealthCheck(ctx) sender.StartHealthCheck(ctx)
@@ -155,7 +149,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
} }
}() }()
testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval testTimeout := (opts.HealthCheckInterval+opts.HealthCheckTimeout)*time.Duration(tc.threshold) + opts.HealthCheckInterval
select { select {
case <-sender.Timeout: case <-sender.Timeout:
@@ -175,39 +169,7 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time") t.Fatalf("sender did not exit in time")
} }
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
}) })
} }
} }
//nolint:tenv
func TestGetAttemptThresholdFromEnv(t *testing.T) {
tests := []struct {
name string
envValue string
expected int
}{
{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
{"Custom attempt threshold when env is set to a valid integer", "3", 3},
{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.envValue == "" {
os.Unsetenv(defaultAttemptThresholdEnv)
} else {
os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
}
result := getAttemptThresholdFromEnv()
if result != tt.expected {
t.Fatalf("Expected %d, got %d", tt.expected, result)
}
os.Unsetenv(defaultAttemptThresholdEnv)
})
}
}