Compare commits

...

18 Commits

Author SHA1 Message Date
Vlad
25ed58328a [management] fix network map dns filter (#4547) 2025-09-25 16:29:14 +02:00
hakansa
644ed4b934 [client] Add WireGuard interface lifecycle monitoring (#4370)
* [client] Add WireGuard interface lifecycle monitoring
2025-09-25 15:36:26 +07:00
Pascal Fischer
58faa341d2 [management] Add logs for update channel (#4527) 2025-09-23 12:06:10 +02:00
Viktor Liu
5853b5553c [client] Skip interface for route lookup if it doesn't exist (#4524) 2025-09-22 14:32:00 +02:00
Zoltan Papp
998fb30e1e [client] Check the client status in the earlier phase (#4509)
This PR improves the NetBird client's status checking mechanism by implementing earlier detection of client state changes and better handling of connection lifecycle management. The key improvements focus on:

  • Enhanced status detection - Added waitForReady option to StatusRequest for improved client status handling
  • Better connection management - Improved context handling for signal and management gRPC connections• Reduced connection timeouts - Increased gRPC dial timeout from 3 to 10 seconds for better reliability
  • Cleaner error handling - Enhanced error propagation and context cancellation in retry loops

  Key Changes

  Core Status Improvements:
  - Added waitForReady optional field to StatusRequest proto (daemon.proto:190)
  - Enhanced status checking logic to detect client state changes earlier in the connection process
  - Improved handling of client permanent exit scenarios from retry loops

  Connection & Context Management:
  - Fixed context cancellation in management and signal client retry mechanisms
  - Added proper context propagation for Login operations
  - Enhanced gRPC connection handling with better timeout management

  Error Handling & Cleanup:
  - Moved feedback channels to upper layers for better separation of concerns
  - Improved error handling patterns throughout the client server implementation
  - Fixed synchronization issues and removed debug logging
2025-09-20 22:14:01 +02:00
Maycon Santos
e254b4cde5 [misc] Update SIGN_PIPE_VER to version 0.0.23 (#4521) 2025-09-20 10:24:04 +02:00
Zoltan Papp
ead1c618ba [client] Do not run up cmd if not needed in docker (#4508)
optimizes the NetBird client startup process by avoiding unnecessary login commands when the peer is already authenticated. The changes increase the default login timeout and expand the log message patterns used to detect successful authentication.

- Increased default login timeout from 1 to 5 seconds for more reliable authentication detection
- Enhanced log pattern matching to detect both registration and ready states
- Added extended regex support for more flexible pattern matching
2025-09-20 10:00:18 +02:00
Viktor Liu
55126f990c [client] Use native windows sock opts to avoid routing loops (#4314)
- Move `util/grpc` and `util/net` to `client` so `internal` packages can be accessed
 - Add methods to return the next best interface after the NetBird interface.
- Use `IP_UNICAST_IF` sock opt to force the outgoing interface for the NetBird `net.Dialer` and `net.ListenerConfig` to avoid routing loops. The interface is picked by the new route lookup method.
- Some refactoring to avoid import cycles
- Old behavior is available through `NB_USE_LEGACY_ROUTING=true` env var
2025-09-20 09:31:04 +02:00
Misha Bragin
90577682e4 Add a new product demo video (#4520) 2025-09-19 13:06:44 +02:00
Bethuel Mmbaga
dc30dcacce [management] Filter DNS records to include only peers to connect (#4517)
DNS record filtering to only include peers that a peer can connect to, reducing unnecessary DNS data in the peer's network map.

- Adds a new `filterZoneRecordsForPeers` function to filter DNS records based on peer connectivity
- Modifies `GetPeerNetworkMap` to use filtered DNS records instead of all records in the custom zone
- Includes comprehensive test coverage for the new filtering functionality
2025-09-18 18:57:07 +02:00
Diego Romar
2c87fa6236 [android] Add OnLoginSuccess callback to URLOpener interface (#4492)
The callback will be fired once login -> internal.Login
completes without errors
2025-09-18 15:07:42 +02:00
hakansa
ec8d83ade4 [client] [UI] Down & Up NetBird Async When Settings Updated
[client] [UI] Down & Up NetBird Async When Settings Updated
2025-09-18 18:13:29 +07:00
Bethuel Mmbaga
3130cce72d [management] Add rule ID validation for policy updates (#4499) 2025-09-15 21:08:16 +03:00
Zoltan Papp
bd23ab925e [client] Fix ICE latency handling (#4501)
The GetSelectedCandidatePair() does not carry the latency information.
2025-09-15 15:08:53 +02:00
Zoltan Papp
0c6f671a7c Refactor healthcheck sender and receiver to use configurable options (#4433) 2025-09-12 09:31:03 +02:00
Bethuel Mmbaga
cf7f6c355f [misc] Remove default zitadel admin user in deployment script (#4482)
* Delete default zitadel-admin user during initialization

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-09-11 21:20:10 +02:00
Zoltan Papp
47e64d72db [client] Fix client status check (#4474)
The client status is not enough to protect the RPC calls from concurrency issues, because it is handled internally in the client in an asynchronous way.
2025-09-11 16:21:09 +02:00
Zoltan Papp
9e81e782e5 [client] Fix/v4 stun routing (#4430)
Deduplicate STUN package sending.
Originally, because every peer shared the same UDP address, the library could not distinguish which STUN message was associated with which candidate. As a result, the Pion library responded from all candidates for every STUN message.
2025-09-11 10:08:54 +02:00
128 changed files with 2350 additions and 1283 deletions

View File

@@ -217,7 +217,7 @@ jobs:
- arch: "386"
raceFlag: ""
- arch: "amd64"
raceFlag: ""
raceFlag: "-race"
runs-on: ubuntu-22.04
steps:
- name: Install Go

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.22"
SIGN_PIPE_VER: "v0.0.23"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"

View File

@@ -1,3 +1,4 @@
<div align="center">
<br/>
<br/>
@@ -52,7 +53,7 @@
### Open Source Network Security in a Single Platform
<img width="1188" alt="centralized-network-management 1" src="https://github.com/user-attachments/assets/c28cc8e4-15d2-4d2f-bb97-a6433db39d56" />
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)

View File

@@ -18,7 +18,7 @@ ENV \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/net"
)
// ConnectionListener export internal Listener for mobile

View File

@@ -33,6 +33,7 @@ type ErrListener interface {
// the backend want to show an url for the user
type URLOpener interface {
Open(string)
OnLoginSuccess()
}
// Auth can register or login new client
@@ -181,6 +182,11 @@ func (a *Auth) login(urlOpener URLOpener) error {
err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)
if err == nil {
go urlOpener.OnLoginSuccess()
}
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}

View File

@@ -27,7 +27,7 @@ var downCmd = &cobra.Command{
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*20)
defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr)

View File

@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {
// DialClientGRPCServer returns client connection to the daemon server.
func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*3)
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
return grpc.DialContext(

View File

@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
status, err := client.Status(ctx, &proto.StatusRequest{
WaitForReady: func() *bool { b := true; return &b }(),
})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}

View File

@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
run := make(chan struct{})
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {

View File

@@ -12,7 +12,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -19,7 +19,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// constants needed to manage and create iptable rules

View File

@@ -14,7 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
func isIptablesSupported() bool {

View File

@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -22,7 +22,7 @@ import (
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -20,8 +20,9 @@ import (
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/util/net"
)
func WithCustomDialer() grpc.DialOption {
@@ -57,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(b, ctx)
}
func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
if tlsEnabled {
certPool, err := x509.SystemCertPool()
@@ -71,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
}))
}
connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
conn, err := grpc.DialContext(

View File

@@ -3,7 +3,7 @@ package bind
import (
wireguard "golang.zx2c4.com/wireguard/conn"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)

View File

@@ -15,8 +15,9 @@ import (
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type RecvMessage struct {
@@ -44,7 +45,7 @@ type ICEBind struct {
RecvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
filterFn udpmux.FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
@@ -54,13 +55,13 @@ type ICEBind struct {
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
address wgaddr.Address
mtu uint16
activityRecorder *ActivityRecorder
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
@@ -115,7 +116,7 @@ func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
func (s *ICEBind) GetICEMux() (*udpmux.UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
@@ -158,8 +159,8 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
s.udpMux = udpmux.NewUniversalUDPMuxDefault(
udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(conn),
Net: s.transportNet,
FilterFn: s.filterFn,

View File

@@ -1,7 +0,0 @@
//go:build ios
package bind
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -17,8 +17,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/monotime"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
@@ -409,7 +409,7 @@ func toBytes(s string) (int64, error) {
}
func getFwmark() int {
if nbnet.AdvancedRouting() {
if nbnet.AdvancedRouting() && runtime.GOOS == "linux" {
return nbnet.ControlPlaneMark
}
return 0

View File

@@ -7,14 +7,14 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -29,7 +30,7 @@ type WGTunDevice struct {
name string
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -88,7 +89,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
}
return t.configurer, nil
}
func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *WGTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -26,7 +27,7 @@ type TunDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -71,7 +72,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -28,7 +29,7 @@ type TunDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -83,7 +84,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -12,11 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/sharedsock"
nbnet "github.com/netbirdio/netbird/util/net"
)
type TunKernelDevice struct {
@@ -31,9 +31,9 @@ type TunKernelDevice struct {
link *wgLink
udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
filterFn bind.FilterFn
filterFn udpmux.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -79,7 +79,7 @@ func (t *TunKernelDevice) Create() (WGConfigurer, error) {
return configurer, nil
}
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@@ -101,19 +101,14 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err
}
var udpConn net.PacketConn = rawSock
if !nbnet.AdvancedRouting() {
udpConn = nbnet.WrapPacketConn(rawSock)
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: udpConn,
bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
mux := udpmux.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock
t.udpMux = mux

View File

@@ -10,8 +10,9 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type TunNetstackDevice struct {
@@ -26,7 +27,7 @@ type TunNetstackDevice struct {
device *device.Device
filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
net *netstack.Net
@@ -80,7 +81,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunNetstackDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -25,7 +26,7 @@ type USPDevice struct {
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -74,7 +75,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *USPDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -29,7 +30,7 @@ type TunDevice struct {
device *device.Device
nativeTunDevice *tun.NativeTun
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
configurer WGConfigurer
}
@@ -104,7 +105,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
func (t *TunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err

View File

@@ -5,14 +5,14 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address
MTU() uint16

View File

@@ -16,9 +16,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -61,7 +61,7 @@ type WGIFaceOpts struct {
MTU uint16
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
FilterFn udpmux.FilterFn
DisableDNS bool
}
@@ -114,7 +114,7 @@ func (r *WGIface) ToInterface() *net.Interface {
// Up configures a Wireguard interface
// The interface must exist before calling this method (e.g. call interface.Create() before)
func (w *WGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
func (w *WGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
w.mu.Lock()
defer w.mu.Unlock()

View File

@@ -1,4 +1,4 @@
package bind
package udpmux
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
@@ -16,11 +16,12 @@ import (
)
type udpMuxedConnParams struct {
Mux *UDPMuxDefault
Mux *SingleSocketUDPMux
AddrPool *sync.Pool
Key string
LocalAddr net.Addr
Logger logging.LeveledLogger
CandidateID string
}
// udpMuxedConn represents a logical packet conn for a single remote as identified by ufrag
@@ -119,6 +120,10 @@ func (c *udpMuxedConn) Close() error {
return err
}
func (c *udpMuxedConn) GetCandidateID() string {
return c.params.CandidateID
}
func (c *udpMuxedConn) isClosed() bool {
select {
case <-c.closedChan:

View File

@@ -0,0 +1,64 @@
// Package udpmux provides a custom implementation of a UDP multiplexer
// that allows multiple logical ICE connections to share a single underlying
// UDP socket. This is based on Pion's ICE library, with modifications for
// NetBird's requirements.
//
// # Background
//
// In WebRTC and NAT traversal scenarios, ICE (Interactive Connectivity
// Establishment) is responsible for discovering candidate network paths
// and maintaining connectivity between peers. Each ICE connection
// normally requires a dedicated UDP socket. However, using one socket
// per candidate can be inefficient and difficult to manage.
//
// This package introduces SingleSocketUDPMux, which allows multiple ICE
// candidate connections (muxed connections) to share a single UDP socket.
// It handles demultiplexing of packets based on ICE ufrag values, STUN
// attributes, and candidate IDs.
//
// # Usage
//
// The typical flow is:
//
// 1. Create a UDP socket (net.PacketConn).
// 2. Construct Params with the socket and optional logger/net stack.
// 3. Call NewSingleSocketUDPMux(params).
// 4. For each ICE candidate ufrag, call GetConn(ufrag, addr, candidateID)
// to obtain a logical PacketConn.
// 5. Use the returned PacketConn just like a normal UDP connection.
//
// # STUN Message Routing Logic
//
// When a STUN packet arrives, the mux decides which connection should
// receive it using this routing logic:
//
// Primary Routing: Candidate Pair ID
// - Extract the candidate pair ID from the STUN message using
// ice.CandidatePairIDFromSTUN(msg)
// - The target candidate is the locally generated candidate that
// corresponds to the connection that should handle this STUN message
// - If found, use the target candidate ID to lookup the specific
// connection in candidateConnMap
// - Route the message directly to that connection
//
// Fallback Routing: Broadcasting
// When candidate pair ID is not available or lookup fails:
// - Collect connections from addressMap based on source address
// - Find connection using username attribute (ufrag) from STUN message
// - Remove duplicate connections from the list
// - Send the STUN message to all collected connections
//
// # Peer Reflexive Candidate Discovery
//
// When a remote peer sends a STUN message from an unknown source address
// (from a candidate that has not been exchanged via signal), the ICE
// library will:
// - Generate a new peer reflexive candidate for this source address
// - Extract or assign a candidate ID based on the STUN message attributes
// - Create a mapping between the new peer reflexive candidate ID and
// the appropriate local connection
//
// This discovery mechanism ensures that STUN messages from newly discovered
// peer reflexive candidates can be properly routed to the correct local
// connection without requiring fallback broadcasting.
package udpmux

View File

@@ -1,4 +1,4 @@
package bind
package udpmux
import (
"fmt"
@@ -22,9 +22,9 @@ import (
const receiveMTU = 8192
// UDPMuxDefault is an implementation of the interface
type UDPMuxDefault struct {
params UDPMuxParams
// SingleSocketUDPMux is an implementation of the interface
type SingleSocketUDPMux struct {
params Params
closedChan chan struct{}
closeOnce sync.Once
@@ -32,6 +32,9 @@ type UDPMuxDefault struct {
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
connsIPv4, connsIPv6 map[string]*udpMuxedConn
// candidateConnMap maps local candidate IDs to their corresponding connection.
candidateConnMap map[string]*udpMuxedConn
addressMapMu sync.RWMutex
addressMap map[string][]*udpMuxedConn
@@ -46,8 +49,8 @@ type UDPMuxDefault struct {
const maxAddrSize = 512
// UDPMuxParams are parameters for UDPMux.
type UDPMuxParams struct {
// Params are parameters for UDPMux.
type Params struct {
Logger logging.LeveledLogger
UDPConn net.PacketConn
@@ -147,17 +150,18 @@ func isZeros(ip net.IP) bool {
return true
}
// NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
// NewSingleSocketUDPMux creates an implementation of UDPMux
func NewSingleSocketUDPMux(params Params) *SingleSocketUDPMux {
if params.Logger == nil {
params.Logger = getLogger()
}
mux := &UDPMuxDefault{
mux := &SingleSocketUDPMux{
addressMap: map[string][]*udpMuxedConn{},
params: params,
connsIPv4: make(map[string]*udpMuxedConn),
connsIPv6: make(map[string]*udpMuxedConn),
candidateConnMap: make(map[string]*udpMuxedConn),
closedChan: make(chan struct{}, 1),
pool: &sync.Pool{
New: func() interface{} {
@@ -171,15 +175,15 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return mux
}
func (m *UDPMuxDefault) updateLocalAddresses() {
func (m *SingleSocketUDPMux) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
// with SingleSocketUDPMux, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("SingleSocketUDPMux should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
@@ -216,13 +220,13 @@ func (m *UDPMuxDefault) updateLocalAddresses() {
m.mu.Unlock()
}
// LocalAddr returns the listening address of this UDPMuxDefault
func (m *UDPMuxDefault) LocalAddr() net.Addr {
// LocalAddr returns the listening address of this SingleSocketUDPMux
func (m *SingleSocketUDPMux) LocalAddr() net.Addr {
return m.params.UDPConn.LocalAddr()
}
// GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
func (m *SingleSocketUDPMux) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
@@ -236,7 +240,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// GetConn returns a PacketConn given the connection's ufrag and network address
// creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
func (m *SingleSocketUDPMux) GetConn(ufrag string, addr net.Addr, candidateID string) (net.PacketConn, error) {
// don't check addr for mux using unspecified address
m.mu.Lock()
lenLocalAddrs := len(m.localAddrsForUnspecified)
@@ -260,12 +264,14 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
return conn, nil
}
c := m.createMuxedConn(ufrag)
c := m.createMuxedConn(ufrag, candidateID)
go func() {
<-c.CloseChannel()
m.RemoveConnByUfrag(ufrag)
}()
m.candidateConnMap[candidateID] = c
if isIPv6 {
m.connsIPv6[ufrag] = c
} else {
@@ -276,7 +282,7 @@ func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, er
}
// RemoveConnByUfrag stops and removes the muxed packet connection
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
func (m *SingleSocketUDPMux) RemoveConnByUfrag(ufrag string) {
removedConns := make([]*udpMuxedConn, 0, 2)
// Keep lock section small to avoid deadlock with conn lock
@@ -284,10 +290,12 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
if c, ok := m.connsIPv4[ufrag]; ok {
delete(m.connsIPv4, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
if c, ok := m.connsIPv6[ufrag]; ok {
delete(m.connsIPv6, ufrag)
removedConns = append(removedConns, c)
delete(m.candidateConnMap, c.GetCandidateID())
}
m.mu.Unlock()
@@ -314,7 +322,7 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
}
// IsClosed returns true if the mux had been closed
func (m *UDPMuxDefault) IsClosed() bool {
func (m *SingleSocketUDPMux) IsClosed() bool {
select {
case <-m.closedChan:
return true
@@ -324,7 +332,7 @@ func (m *UDPMuxDefault) IsClosed() bool {
}
// Close the mux, no further connections could be created
func (m *UDPMuxDefault) Close() error {
func (m *SingleSocketUDPMux) Close() error {
var err error
m.closeOnce.Do(func() {
m.mu.Lock()
@@ -347,11 +355,11 @@ func (m *UDPMuxDefault) Close() error {
return err
}
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
func (m *SingleSocketUDPMux) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
return m.params.UDPConn.WriteTo(buf, rAddr)
}
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
func (m *SingleSocketUDPMux) registerConnForAddress(conn *udpMuxedConn, addr string) {
if m.IsClosed() {
return
}
@@ -368,81 +376,109 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
}
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
func (m *SingleSocketUDPMux) createMuxedConn(key string, candidateID string) *udpMuxedConn {
c := newUDPMuxedConn(&udpMuxedConnParams{
Mux: m,
Key: key,
AddrPool: m.pool,
LocalAddr: m.LocalAddr(),
Logger: m.params.Logger,
CandidateID: candidateID,
})
return c
}
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
func (m *SingleSocketUDPMux) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
remoteAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
}
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
m.addressMapMu.RLock()
// Try to route to specific candidate connection first
if conn := m.findCandidateConnection(msg); conn != nil {
return conn.writePacket(msg.Raw, remoteAddr)
}
// Fallback: route to all possible connections
return m.forwardToAllConnections(msg, addr, remoteAddr)
}
// findCandidateConnection attempts to find the specific connection for a STUN message
func (m *SingleSocketUDPMux) findCandidateConnection(msg *stun.Message) *udpMuxedConn {
candidatePairID, ok, err := ice.CandidatePairIDFromSTUN(msg)
if err != nil {
return nil
} else if !ok {
return nil
}
m.mu.Lock()
defer m.mu.Unlock()
conn, exists := m.candidateConnMap[candidatePairID.TargetCandidateID()]
if !exists {
return nil
}
return conn
}
// forwardToAllConnections forwards STUN message to all relevant connections
func (m *SingleSocketUDPMux) forwardToAllConnections(msg *stun.Message, addr net.Addr, remoteAddr *net.UDPAddr) error {
var destinationConnList []*udpMuxedConn
// Add connections from address map
m.addressMapMu.RLock()
if storedConns, ok := m.addressMap[addr.String()]; ok {
destinationConnList = append(destinationConnList, storedConns...)
}
m.addressMapMu.RUnlock()
var isIPv6 bool
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
isIPv6 = true
if conn, ok := m.findConnectionByUsername(msg, addr); ok {
// If we have already seen this address dispatch to the appropriate destination
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
// We will then forward STUN packets to each of these connections.
if !m.connectionExists(conn, destinationConnList) {
destinationConnList = append(destinationConnList, conn)
}
}
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
// However, we can take a username attribute from the STUN message which contains ufrag.
// We can use ufrag to identify the destination conn to route packet to.
attr, stunAttrErr := msg.Get(stun.AttrUsername)
if stunAttrErr == nil {
ufrag := strings.Split(string(attr), ":")[0]
m.mu.Lock()
destinationConn := m.connsIPv4[ufrag]
if isIPv6 {
destinationConn = m.connsIPv6[ufrag]
}
if destinationConn != nil {
exists := false
for _, conn := range destinationConnList {
if conn.params.Key == destinationConn.params.Key {
exists = true
break
}
}
if !exists {
destinationConnList = append(destinationConnList, destinationConn)
}
}
m.mu.Unlock()
}
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
// It will be discarded by the further ICE candidate logic if so.
// Forward to all found connections
for _, conn := range destinationConnList {
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
log.Errorf("could not write packet: %v", err)
}
}
return nil
}
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
// findConnectionByUsername finds connection using username attribute from STUN message
func (m *SingleSocketUDPMux) findConnectionByUsername(msg *stun.Message, addr net.Addr) (*udpMuxedConn, bool) {
attr, err := msg.Get(stun.AttrUsername)
if err != nil {
return nil, false
}
ufrag := strings.Split(string(attr), ":")[0]
isIPv6 := isIPv6Address(addr)
m.mu.Lock()
defer m.mu.Unlock()
return m.getConn(ufrag, isIPv6)
}
// connectionExists checks if a connection already exists in the list
func (m *SingleSocketUDPMux) connectionExists(target *udpMuxedConn, conns []*udpMuxedConn) bool {
for _, conn := range conns {
if conn.params.Key == target.params.Key {
return true
}
}
return false
}
func (m *SingleSocketUDPMux) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
if isIPv6 {
val, ok = m.connsIPv6[ufrag]
} else {
@@ -451,6 +487,13 @@ func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, o
return
}
func isIPv6Address(addr net.Addr) bool {
if udpAddr, ok := addr.(*net.UDPAddr); ok {
return udpAddr.IP.To4() == nil
}
return false
}
type bufferHolder struct {
buf []byte
}

View File

@@ -1,12 +1,12 @@
//go:build !ios
package bind
package udpmux
import (
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// Kernel mode: direct nbnet.PacketConn (SharedSocket wrapped with nbnet)
if conn, ok := m.params.UDPConn.(*nbnet.PacketConn); ok {
conn.RemoveAddress(addr)

View File

@@ -0,0 +1,7 @@
//go:build ios
package udpmux
func (m *SingleSocketUDPMux) notifyAddressRemoval(addr string) {
// iOS doesn't support nbnet hooks, so this is a no-op
}

View File

@@ -1,4 +1,4 @@
package bind
package udpmux
/*
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
@@ -29,7 +29,7 @@ type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
*UDPMuxDefault
*SingleSocketUDPMux
params UniversalUDPMuxParams
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
@@ -72,12 +72,12 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
address: params.WGAddress,
}
udpMuxParams := UDPMuxParams{
udpMuxParams := Params{
Logger: params.Logger,
UDPConn: m.params.UDPConn,
Net: m.params.Net,
}
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
m.SingleSocketUDPMux = NewSingleSocketUDPMux(udpMuxParams)
return m
}
@@ -211,8 +211,8 @@ func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
// and return a unique connection per server.
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr, candidateID string) (net.PacketConn, error) {
return m.SingleSocketUDPMux.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr, candidateID)
}
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
@@ -233,7 +233,7 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
}
return nil
}
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
return m.SingleSocketUDPMux.HandleSTUNMessage(msg, addr)
}
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.

View File

@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (

View File

@@ -34,7 +34,7 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -280,15 +280,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
if runningChan != nil {
select {
case runningChan <- struct{}{}:
default:
}
close(runningChan)
runningChan = nil
}
<-engineCtx.Done()

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type ServiceViaMemory struct {

View File

@@ -10,7 +10,7 @@ import (
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type upstreamResolver struct {

View File

@@ -29,9 +29,9 @@ import (
"github.com/netbirdio/netbird/client/firewall"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
@@ -166,7 +166,7 @@ type Engine struct {
wgInterface WGIface
udpMux *bind.UniversalUDPMuxDefault
udpMux *udpmux.UniversalUDPMuxDefault
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
@@ -198,6 +198,10 @@ type Engine struct {
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
}
// Peer is an instance of the Connection Peer
@@ -341,6 +345,9 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
return nil
}
@@ -446,6 +453,8 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("up wg interface: %w", err)
}
// if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall)
@@ -461,7 +470,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
@@ -477,6 +486,22 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
}()
return nil
}
@@ -1326,7 +1351,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMux: e.udpMux.SingleSocketUDPMux,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
},

View File

@@ -26,10 +26,11 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -84,7 +85,7 @@ type MockWGIface struct {
NameFunc func() string
AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpFunc func() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
@@ -134,7 +135,7 @@ func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
func (m *MockWGIface) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
@@ -413,7 +414,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil {
t.Fatal(err)
}
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.udpMux = udpmux.NewUniversalUDPMuxDefault(udpmux.UniversalUDPMuxParams{UDPConn: conn, MTU: 1280})
engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
engine.connMgr = NewConnMgr(engine.config, engine.statusRecorder, engine.peerStore, wgIface)

View File

@@ -9,9 +9,9 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/monotime"
@@ -24,7 +24,7 @@ type wgIfaceBase interface {
Name() string
Address() wgaddr.Address
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
Up() (*udpmux.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error

View File

@@ -14,7 +14,7 @@ import (
"github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
const defaultChannelSize = 100

View File

@@ -9,11 +9,10 @@ import (
"time"
"github.com/pion/ice/v4"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/internal/peer/conntype"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/stdnet"
@@ -55,10 +54,6 @@ type WorkerICE struct {
sessionID ICESessionID
muxAgent sync.Mutex
StunTurn []*stun.URI
sentExtraSrflx bool
localUfrag string
localPwd string
@@ -139,7 +134,6 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.muxAgent.Unlock()
return
}
w.sentExtraSrflx = false
w.agent = agent
w.agentDialerCancel = dialerCancel
w.agentConnecting = true
@@ -166,6 +160,21 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
w.log.Errorf("error while handling remote candidate")
return
}
if shouldAddExtraCandidate(candidate) {
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
if err := w.agent.AddRemoteCandidate(extraSrflx); err != nil {
w.log.Errorf("error while handling remote candidate")
return
}
}
}
func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
@@ -209,7 +218,9 @@ func (w *WorkerICE) reCreateAgent(dialerCancel context.CancelFunc, candidates []
return nil, err
}
if err := agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair); err != nil {
if err := agent.OnSelectedCandidatePairChange(func(c1, c2 ice.Candidate) {
w.onICESelectedCandidatePair(agent, c1, c2)
}); err != nil {
return nil, err
}
@@ -327,7 +338,7 @@ func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int)
return
}
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*udpmux.UniversalUDPMuxDefault)
if !ok {
w.log.Warn("invalid udp mux conversion")
return
@@ -354,48 +365,19 @@ func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
}
}()
if !w.shouldSendExtraSrflxCandidate(candidate) {
return
}
// sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
// this is useful when network has an existing port forwarding rule for the wireguard port and this peer
extraSrflx, err := extraSrflxCandidate(candidate)
if err != nil {
w.log.Errorf("failed creating extra server reflexive candidate %s", err)
return
}
w.sentExtraSrflx = true
go func() {
err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
if err != nil {
w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
}
}()
}
func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
func (w *WorkerICE) onICESelectedCandidatePair(agent *icemaker.ThreadSafeAgent, c1, c2 ice.Candidate) {
w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
w.config.Key)
w.muxAgent.Lock()
pair, err := w.agent.GetSelectedCandidatePair()
if err != nil {
w.log.Warnf("failed to get selected candidate pair: %s", err)
w.muxAgent.Unlock()
pairStat, ok := agent.GetSelectedCandidatePairStats()
if !ok {
w.log.Warnf("failed to get selected candidate pair stats")
return
}
if pair == nil {
w.log.Warnf("selected candidate pair is nil, cannot proceed")
w.muxAgent.Unlock()
return
}
w.muxAgent.Unlock()
duration := time.Duration(pair.CurrentRoundTripTime() * float64(time.Second))
duration := time.Duration(pairStat.CurrentRoundTripTime * float64(time.Second))
if err := w.statusRecorder.UpdateLatency(w.config.Key, duration); err != nil {
w.log.Debugf("failed to update latency for peer: %s", err)
return
@@ -424,22 +406,31 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
}
}
func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
return true
}
return false
}
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
isControlling := w.config.LocalKey > w.config.Key
if isControlling {
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
if isController(w.config) {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}
}
func shouldAddExtraCandidate(candidate ice.Candidate) bool {
if candidate.Type() != ice.CandidateTypeServerReflexive {
return false
}
if candidate.Port() == candidate.RelatedAddress().Port {
return false
}
// in the older version when we didn't set candidate ID extension the remote peer sent the extra candidates
// in newer version we generate locally the extra candidate
if _, ok := candidate.GetExtension(ice.ExtensionKeyCandidateID); !ok {
return false
}
return true
}
func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
relatedAdd := candidate.RelatedAddress()
ec, err := ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
@@ -455,6 +446,10 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
}
for _, e := range candidate.Extensions() {
// overwrite the original candidate ID with the new one to avoid candidate duplication
if e.Key == ice.ExtensionKeyCandidateID {
e.Value = candidate.ID()
}
if err := ec.AddExtension(e); err != nil {
return nil, err
}

View File

@@ -12,7 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// ProbeResult holds the info about the result of a relay probe request

View File

@@ -36,9 +36,9 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
)
@@ -108,6 +108,10 @@ func NewManager(config ManagerConfig) *DefaultManager {
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
if runtime.GOOS == "windows" && config.WGInterface != nil {
nbnet.SetVPNInterfaceName(config.WGInterface.Name())
}
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
@@ -208,7 +212,7 @@ func (m *DefaultManager) Init() error {
return nil
}
if err := m.sysOps.CleanupRouting(nil); err != nil {
if err := m.sysOps.CleanupRouting(nil, nbnet.AdvancedRouting()); err != nil {
log.Warnf("Failed cleaning up routing: %v", err)
}
@@ -219,7 +223,7 @@ func (m *DefaultManager) Init() error {
ips := resolveURLsToIPs(initialAddresses)
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
if err := m.sysOps.SetupRouting(ips, m.stateManager, nbnet.AdvancedRouting()); err != nil {
return fmt.Errorf("setup routing: %w", err)
}
@@ -285,11 +289,15 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
if err := m.sysOps.CleanupRouting(stateManager, nbnet.AdvancedRouting()); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
log.Info("Routing cleanup complete")
}
if runtime.GOOS == "windows" {
nbnet.SetVPNInterfaceName("")
}
}
m.mux.Lock()

View File

@@ -12,11 +12,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
return nil
}
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
return nil
}

View File

@@ -3,7 +3,6 @@
package systemops
import (
"context"
"errors"
"fmt"
"net"
@@ -22,7 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/client/net/hooks"
)
const localSubnetsCacheTTL = 15 * time.Minute
@@ -96,9 +95,9 @@ func (r *SysOps) cleanupRefCounter(stateManager *statemanager.Manager) error {
return nil
}
// TODO: Remove hooks selectively
nbnet.RemoveDialerHooks()
nbnet.RemoveListenerHooks()
hooks.RemoveWriteHooks()
hooks.RemoveCloseHooks()
hooks.RemoveAddressRemoveHooks()
if err := r.refCounter.Flush(); err != nil {
return fmt.Errorf("flush route manager: %w", err)
@@ -290,12 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
}
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
return fmt.Errorf("convert ip to prefix: %w", err)
}
beforeHook := func(connID hooks.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
@@ -304,7 +298,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
afterHook := func(connID hooks.ConnectionID) error {
if err := r.refCounter.DecrementWithID(string(connID)); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}
@@ -317,36 +311,20 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
var merr *multierror.Error
for _, ip := range initAddresses {
if err := beforeHook("init", ip); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid IP address %s: %w", ip, err))
continue
}
if err := beforeHook("init", prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", prefix, err))
}
}
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
if ctx.Err() != nil {
return ctx.Err()
}
hooks.AddWriteHook(beforeHook)
hooks.AddCloseHook(afterHook)
var merr *multierror.Error
for _, ip := range resolvedIPs {
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
}
return nberrors.FormatErrorOrNil(merr)
})
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
return afterHook(connID)
})
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
return beforeHook(connID, ip.IP)
})
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
return afterHook(connID)
})
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
hooks.AddAddressRemoveHook(func(connID hooks.ConnectionID, prefix netip.Prefix) error {
if _, err := r.refCounter.Decrement(prefix); err != nil {
return fmt.Errorf("remove route reference: %w", err)
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
type dialer interface {
@@ -143,10 +144,11 @@ func TestAddVPNRoute(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
intf, err := net.InterfaceByName(wgInterface.Name())
@@ -341,10 +343,11 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
@@ -484,10 +487,11 @@ func setupTestEnv(t *testing.T) {
})
r := NewSysOps(wgInterface, nil)
err := r.SetupRouting(nil, nil)
advancedRouting := nbnet.AdvancedRouting()
err := r.SetupRouting(nil, nil, advancedRouting)
require.NoError(t, err, "setupRouting should not return err")
t.Cleanup(func() {
assert.NoError(t, r.CleanupRouting(nil))
assert.NoError(t, r.CleanupRouting(nil, advancedRouting))
})
index, err := net.InterfaceByName(wgInterface.Name())

View File

@@ -12,14 +12,14 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager, bool) error {
r.mu.Lock()
defer r.mu.Unlock()
r.prefixes = make(map[netip.Prefix]struct{})
return nil
}
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
func (r *SysOps) CleanupRouting(*statemanager.Manager, bool) error {
r.mu.Lock()
defer r.mu.Unlock()

View File

@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/sysctl"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// IPRule contains IP rule information for debugging
@@ -94,15 +94,15 @@ func getSetupRules() []ruleParams {
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
// This table is where a default route or other specific routes received from the management server are configured,
// enabling VPN connectivity.
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
if !nbnet.AdvancedRouting() {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) (err error) {
if !advancedRouting {
log.Infof("Using legacy routing setup")
return r.setupRefCounter(initAddresses, stateManager)
}
defer func() {
if err != nil {
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
if cleanErr := r.CleanupRouting(stateManager, advancedRouting); cleanErr != nil {
log.Errorf("Error cleaning up routing: %v", cleanErr)
}
}
@@ -132,8 +132,8 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
// The function uses error aggregation to report any errors encountered during the cleanup process.
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
if !nbnet.AdvancedRouting() {
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if !advancedRouting {
return r.cleanupRefCounter(stateManager)
}

View File

@@ -20,11 +20,11 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
return r.cleanupRefCounter(stateManager)
}

View File

@@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
type PacketExpectation struct {

View File

@@ -8,6 +8,7 @@ import (
"net/netip"
"os"
"runtime/debug"
"sort"
"strconv"
"sync"
"syscall"
@@ -19,9 +20,16 @@ import (
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const InfiniteLifetime = 0xffffffff
func init() {
nbnet.GetBestInterfaceFunc = GetBestInterface
}
const (
InfiniteLifetime = 0xffffffff
)
type RouteUpdateType int
@@ -77,6 +85,14 @@ type MIB_IPFORWARD_TABLE2 struct {
Table [1]MIB_IPFORWARD_ROW2 // Flexible array member
}
// candidateRoute represents a potential route for selection during route lookup
type candidateRoute struct {
interfaceIndex uint32
prefixLength uint8
routeMetric uint32
interfaceMetric int
}
// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
type IP_ADDRESS_PREFIX struct {
Prefix SOCKADDR_INET
@@ -177,11 +193,20 @@ const (
RouteDeleted
)
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return nil
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager) error {
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return nil
}
return r.cleanupRefCounter(stateManager)
}
@@ -635,10 +660,7 @@ func getWindowsRoutingTable() (*MIB_IPFORWARD_TABLE2, error) {
func freeWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) {
if table != nil {
ret, _, _ := procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
if ret != 0 {
log.Warnf("FreeMibTable failed with return code: %d", ret)
}
_, _, _ = procFreeMibTable.Call(uintptr(unsafe.Pointer(table)))
}
}
@@ -652,8 +674,7 @@ func parseWindowsRoutingTable(table *MIB_IPFORWARD_TABLE2) []DetailedRoute {
entryPtr := basePtr + uintptr(i)*entrySize
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
detailed := buildWindowsDetailedRoute(entry)
if detailed != nil {
if detailed := buildWindowsDetailedRoute(entry); detailed != nil {
detailedRoutes = append(detailedRoutes, *detailed)
}
}
@@ -802,6 +823,46 @@ func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
return ip
}
// parseCandidatesFromTable extracts all matching candidate routes from the routing table
func parseCandidatesFromTable(table *MIB_IPFORWARD_TABLE2, dest netip.Addr, skipInterfaceIndex int) []candidateRoute {
var candidates []candidateRoute
entrySize := unsafe.Sizeof(MIB_IPFORWARD_ROW2{})
basePtr := uintptr(unsafe.Pointer(&table.Table[0]))
for i := uint32(0); i < table.NumEntries; i++ {
entryPtr := basePtr + uintptr(i)*entrySize
entry := (*MIB_IPFORWARD_ROW2)(unsafe.Pointer(entryPtr))
if candidate := parseCandidateRoute(entry, dest, skipInterfaceIndex); candidate != nil {
candidates = append(candidates, *candidate)
}
}
return candidates
}
// parseCandidateRoute extracts candidate route information from a MIB_IPFORWARD_ROW2 entry
// Returns nil if the route doesn't match the destination or should be skipped
func parseCandidateRoute(entry *MIB_IPFORWARD_ROW2, dest netip.Addr, skipInterfaceIndex int) *candidateRoute {
if skipInterfaceIndex > 0 && int(entry.InterfaceIndex) == skipInterfaceIndex {
return nil
}
destPrefix := parseIPPrefix(entry.DestinationPrefix, int(entry.InterfaceIndex))
if !destPrefix.IsValid() || !destPrefix.Contains(dest) {
return nil
}
interfaceMetric := getInterfaceMetric(entry.InterfaceIndex, entry.DestinationPrefix.Prefix.sin6_family)
return &candidateRoute{
interfaceIndex: entry.InterfaceIndex,
prefixLength: entry.DestinationPrefix.PrefixLength,
routeMetric: entry.Metric,
interfaceMetric: interfaceMetric,
}
}
// getInterfaceMetric retrieves the interface metric for a given interface and address family
func getInterfaceMetric(interfaceIndex uint32, family int16) int {
if interfaceIndex == 0 {
@@ -821,6 +882,76 @@ func getInterfaceMetric(interfaceIndex uint32, family int16) int {
return int(ipInterfaceRow.Metric)
}
// sortRouteCandidates sorts route candidates by priority: prefix length -> route metric -> interface metric
func sortRouteCandidates(candidates []candidateRoute) {
sort.Slice(candidates, func(i, j int) bool {
if candidates[i].prefixLength != candidates[j].prefixLength {
return candidates[i].prefixLength > candidates[j].prefixLength
}
if candidates[i].routeMetric != candidates[j].routeMetric {
return candidates[i].routeMetric < candidates[j].routeMetric
}
return candidates[i].interfaceMetric < candidates[j].interfaceMetric
})
}
// GetBestInterface finds the best interface for reaching a destination,
// excluding the VPN interface to avoid routing loops.
//
// Route selection priority:
// 1. Longest prefix match (most specific route)
// 2. Lowest route metric
// 3. Lowest interface metric
func GetBestInterface(dest netip.Addr, vpnIntf string) (*net.Interface, error) {
var skipInterfaceIndex int
if vpnIntf != "" {
if iface, err := net.InterfaceByName(vpnIntf); err == nil {
skipInterfaceIndex = iface.Index
} else {
// not critical, if we cannot get ahold of the interface then we won't need to skip it
log.Warnf("failed to get VPN interface %s: %v", vpnIntf, err)
}
}
table, err := getWindowsRoutingTable()
if err != nil {
return nil, fmt.Errorf("get routing table: %w", err)
}
defer freeWindowsRoutingTable(table)
candidates := parseCandidatesFromTable(table, dest, skipInterfaceIndex)
if len(candidates) == 0 {
return nil, fmt.Errorf("no route to %s", dest)
}
// Sort routes: prefix length -> route metric -> interface metric
sortRouteCandidates(candidates)
for _, candidate := range candidates {
iface, err := net.InterfaceByIndex(int(candidate.interfaceIndex))
if err != nil {
log.Warnf("failed to get interface by index %d: %v", candidate.interfaceIndex, err)
continue
}
if iface.Flags&net.FlagLoopback != 0 && !dest.IsLoopback() {
continue
}
if iface.Flags&net.FlagUp == 0 {
log.Debugf("interface %s is down, trying next route", iface.Name)
continue
}
log.Debugf("route lookup for %s: selected interface %s (index %d), route metric %d, interface metric %d",
dest, iface.Name, iface.Index, candidate.routeMetric, candidate.interfaceMetric)
return iface, nil
}
return nil, fmt.Errorf("no usable interface found for %s", dest)
}
// formatRouteAge formats the route age in seconds to a human-readable string
func formatRouteAge(ageSeconds uint32) string {
if ageSeconds == 0 {

View File

@@ -15,7 +15,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
var (

View File

@@ -12,18 +12,8 @@ func GetPrefixFromIP(ip net.IP) (netip.Prefix, error) {
if !ok {
return netip.Prefix{}, fmt.Errorf("parse IP address: %s", ip)
}
addr = addr.Unmap()
var prefixLength int
switch {
case addr.Is4():
prefixLength = 32
case addr.Is6():
prefixLength = 128
default:
return netip.Prefix{}, fmt.Errorf("invalid IP address: %s", addr)
}
prefix := netip.PrefixFrom(addr, prefixLength)
prefix := netip.PrefixFrom(addr, addr.BitLen())
return prefix, nil
}

View File

@@ -5,7 +5,7 @@ import (
"github.com/pion/transport/v3"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// Dial connects to the address on the named network.

View File

@@ -6,7 +6,7 @@ import (
"github.com/pion/transport/v3"
nbnet "github.com/netbirdio/netbird/util/net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// ListenPacket listens for incoming packets on the given network and address.

View File

@@ -0,0 +1,98 @@
package internal
import (
"context"
"errors"
"fmt"
"net"
"runtime"
"time"
log "github.com/sirupsen/logrus"
)
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
// if the interface is deleted externally while the engine is running.
type WGIfaceMonitor struct {
done chan struct{}
}
// NewWGIfaceMonitor creates a new WGIfaceMonitor instance.
func NewWGIfaceMonitor() *WGIfaceMonitor {
return &WGIfaceMonitor{
done: make(chan struct{}),
}
}
// Start begins monitoring the WireGuard interface.
// It relies on the provided context cancellation to stop.
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
defer close(m.done)
// Skip on mobile platforms as they handle interface lifecycle differently
if runtime.GOOS == "android" || runtime.GOOS == "ios" {
log.Debugf("Interface monitor: skipped on %s platform", runtime.GOOS)
return false, errors.New("not supported on mobile platforms")
}
if ifaceName == "" {
log.Debugf("Interface monitor: empty interface name, skipping monitor")
return false, errors.New("empty interface name")
}
// Get initial interface index to track the specific interface instance
expectedIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
log.Debugf("Interface monitor: interface %s not found, skipping monitor", ifaceName)
return false, fmt.Errorf("interface %s not found: %w", ifaceName, err)
}
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
}
// getInterfaceIndex returns the index of a network interface by name.
// Returns an error if the interface is not found.
func getInterfaceIndex(name string) (int, error) {
if name == "" {
return 0, fmt.Errorf("empty interface name")
}
ifi, err := net.InterfaceByName(name)
if err != nil {
// Check if it's specifically a "not found" error
if errors.Is(err, &net.OpError{}) {
// On some systems, this might be a "not found" error
return 0, fmt.Errorf("interface not found: %w", err)
}
return 0, fmt.Errorf("failed to lookup interface: %w", err)
}
if ifi == nil {
return 0, fmt.Errorf("interface not found")
}
return ifi.Index, nil
}

49
client/net/conn.go Normal file
View File

@@ -0,0 +1,49 @@
//go:build !ios
package net
import (
"io"
"net"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/net/hooks"
)
// Conn wraps a net.Conn to override the Close method
type Conn struct {
net.Conn
ID hooks.ConnectionID
}
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
func (c *Conn) Close() error {
return closeConn(c.ID, c.Conn)
}
// TCPConn wraps net.TCPConn to override its Close method to include hook functionality.
type TCPConn struct {
*net.TCPConn
ID hooks.ConnectionID
}
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
func (c *TCPConn) Close() error {
return closeConn(c.ID, c.TCPConn)
}
// closeConn is a helper function to close connections and execute close hooks.
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
err := conn.Close()
closeHooks := hooks.GetCloseHooks()
for _, hook := range closeHooks {
if err := hook(id); err != nil {
log.Errorf("Error executing close hook: %v", err)
}
}
return err
}

82
client/net/dial.go Normal file
View File

@@ -0,0 +1,82 @@
//go:build !ios
package net
import (
"fmt"
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
func DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.DialUDP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
}
switch c := conn.(type) {
case *net.UDPConn:
// Advanced routing: plain connection
return c, nil
case *Conn:
// Legacy routing: wrapped connection preserves close hooks
udpConn, ok := c.Conn.(*net.UDPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDP connection, got %T", c.Conn)
}
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}
return nil, fmt.Errorf("unexpected connection type: %T", conn)
}
func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
if CustomRoutingDisabled() {
return net.DialTCP(network, laddr, raddr)
}
dialer := NewDialer()
dialer.LocalAddr = laddr
conn, err := dialer.Dial(network, raddr.String())
if err != nil {
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
}
switch c := conn.(type) {
case *net.TCPConn:
// Advanced routing: plain connection
return c, nil
case *Conn:
// Legacy routing: wrapped connection preserves close hooks
tcpConn, ok := c.Conn.(*net.TCPConn)
if !ok {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected TCP connection, got %T", c.Conn)
}
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}
return nil, fmt.Errorf("unexpected connection type: %T", conn)
}

View File

@@ -16,6 +16,5 @@ func NewDialer() *Dialer {
Dialer: &net.Dialer{},
}
dialer.init()
return dialer
}

87
client/net/dialer_dial.go Normal file
View File

@@ -0,0 +1,87 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/net/hooks"
)
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
log.Debugf("Dialing %s %s", network, address)
if CustomRoutingDisabled() || AdvancedRouting() {
return d.Dialer.DialContext(ctx, network, address)
}
connID := hooks.GenerateConnID()
if err := callDialerHooks(ctx, connID, address, d.Resolver); err != nil {
log.Errorf("Failed to call dialer hooks: %v", err)
}
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
// Wrap the connection in Conn to handle Close with hooks
return &Conn{Conn: conn, ID: connID}, nil
}
// Dial wraps the net.Dialer's Dial method to use the custom connection
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
return d.DialContext(context.Background(), network, address)
}
func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address string, customResolver *net.Resolver) error {
if ctx.Err() != nil {
return ctx.Err()
}
writeHooks := hooks.GetWriteHooks()
if len(writeHooks) == 0 {
return nil
}
host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("split host and port: %w", err)
}
resolver := customResolver
if resolver == nil {
resolver = net.DefaultResolver
}
ips, err := resolver.LookupIPAddr(ctx, host)
if err != nil {
return fmt.Errorf("failed to resolve address %s: %w", address, err)
}
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
var merr *multierror.Error
for _, ip := range ips {
prefix, err := util.GetPrefixFromIP(ip.IP)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("convert IP %s to prefix: %w", ip.IP, err))
continue
}
for _, hook := range writeHooks {
if err := hook(connID, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("executing dial hook for IP %s: %w", ip.IP, err))
}
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,7 @@
//go:build !linux && !windows
package net
func (d *Dialer) init() {
// implemented on Linux, Android, and Windows only
}

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyUnicastIFToSocket
}

View File

@@ -11,6 +11,7 @@ import (
const (
envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING"
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
)
// CustomRoutingDisabled returns true if custom routing is disabled.

24
client/net/env_android.go Normal file
View File

@@ -0,0 +1,24 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

23
client/net/env_generic.go Normal file
View File

@@ -0,0 +1,23 @@
//go:build !linux && !windows && !android
package net
// Init initializes the network environment (no-op on non-Linux/Windows platforms)
func Init() {
// No-op on non-Linux/Windows platforms
}
// AdvancedRouting returns false on non-Linux/Windows platforms
func AdvancedRouting() bool {
return false
}
// SetVPNInterfaceName is a no-op on non-Windows platforms
func SetVPNInterfaceName(name string) {
// No-op on non-Windows platforms
}
// GetVPNInterfaceName returns empty string on non-Windows platforms
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -18,7 +18,6 @@ import (
const (
// these have the same effect, skip socket env supported for backward compatibility
envSkipSocketMark = "NB_SKIP_SOCKET_MARK"
envUseLegacyRouting = "NB_USE_LEGACY_ROUTING"
)
var advancedRoutingSupported bool
@@ -27,6 +26,7 @@ func Init() {
advancedRoutingSupported = checkAdvancedRoutingSupport()
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
func AdvancedRouting() bool {
return advancedRoutingSupported
}
@@ -73,7 +73,7 @@ func checkAdvancedRoutingSupport() bool {
}
func CheckFwmarkSupport() bool {
// temporarily enable advanced routing to check fwmarks are supported
// temporarily enable advanced routing to check if fwmarks are supported
old := advancedRoutingSupported
advancedRoutingSupported = true
defer func() {
@@ -129,3 +129,13 @@ func CheckRuleOperationsSupport() bool {
}
return true
}
// SetVPNInterfaceName is a no-op on Linux
func SetVPNInterfaceName(name string) {
// No-op on Linux - not needed for fwmark-based routing
}
// GetVPNInterfaceName returns empty string on Linux
func GetVPNInterfaceName() string {
return ""
}

67
client/net/env_windows.go Normal file
View File

@@ -0,0 +1,67 @@
//go:build windows
package net
import (
"os"
"strconv"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
)
var (
vpnInterfaceName string
vpnInitMutex sync.RWMutex
advancedRoutingSupported bool
)
func Init() {
advancedRoutingSupported = checkAdvancedRoutingSupport()
}
func checkAdvancedRoutingSupport() bool {
var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" {
legacyRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
}
}
if legacyRouting || netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
return false
}
log.Info("system supports advanced routing")
return true
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes
func AdvancedRouting() bool {
return advancedRoutingSupported
}
// GetVPNInterfaceName returns the stored VPN interface name
func GetVPNInterfaceName() string {
vpnInitMutex.RLock()
defer vpnInitMutex.RUnlock()
return vpnInterfaceName
}
// SetVPNInterfaceName sets the VPN interface name for lazy initialization
func SetVPNInterfaceName(name string) {
vpnInitMutex.Lock()
defer vpnInitMutex.Unlock()
vpnInterfaceName = name
if name != "" {
log.Infof("VPN interface name set to %s for route exclusion", name)
}
}

93
client/net/hooks/hooks.go Normal file
View File

@@ -0,0 +1,93 @@
package hooks
import (
"net/netip"
"slices"
"sync"
"github.com/google/uuid"
)
// ConnectionID provides a globally unique identifier for network connections.
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
type ConnectionID string
// GenerateConnID generates a unique identifier for each connection.
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
type WriteHookFunc func(connID ConnectionID, prefix netip.Prefix) error
type CloseHookFunc func(connID ConnectionID) error
type AddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
var (
hooksMutex sync.RWMutex
writeHooks []WriteHookFunc
closeHooks []CloseHookFunc
addressRemoveHooks []AddressRemoveHookFunc
)
// AddWriteHook allows adding a new hook to be executed before writing/dialing.
func AddWriteHook(hook WriteHookFunc) {
hooksMutex.Lock()
defer hooksMutex.Unlock()
writeHooks = append(writeHooks, hook)
}
// AddCloseHook allows adding a new hook to be executed on connection close.
func AddCloseHook(hook CloseHookFunc) {
hooksMutex.Lock()
defer hooksMutex.Unlock()
closeHooks = append(closeHooks, hook)
}
// RemoveWriteHooks removes all write hooks.
func RemoveWriteHooks() {
hooksMutex.Lock()
defer hooksMutex.Unlock()
writeHooks = nil
}
// RemoveCloseHooks removes all close hooks.
func RemoveCloseHooks() {
hooksMutex.Lock()
defer hooksMutex.Unlock()
closeHooks = nil
}
// AddAddressRemoveHook allows adding a new hook to be executed when an address is removed.
func AddAddressRemoveHook(hook AddressRemoveHookFunc) {
hooksMutex.Lock()
defer hooksMutex.Unlock()
addressRemoveHooks = append(addressRemoveHooks, hook)
}
// RemoveAddressRemoveHooks removes all listener address hooks.
func RemoveAddressRemoveHooks() {
hooksMutex.Lock()
defer hooksMutex.Unlock()
addressRemoveHooks = nil
}
// GetWriteHooks returns a copy of the current write hooks.
func GetWriteHooks() []WriteHookFunc {
hooksMutex.RLock()
defer hooksMutex.RUnlock()
return slices.Clone(writeHooks)
}
// GetCloseHooks returns a copy of the current close hooks.
func GetCloseHooks() []CloseHookFunc {
hooksMutex.RLock()
defer hooksMutex.RUnlock()
return slices.Clone(closeHooks)
}
// GetAddressRemoveHooks returns a copy of the current listener address remove hooks.
func GetAddressRemoveHooks() []AddressRemoveHookFunc {
hooksMutex.RLock()
defer hooksMutex.RUnlock()
return slices.Clone(addressRemoveHooks)
}

47
client/net/listen.go Normal file
View File

@@ -0,0 +1,47 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"sync"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
)
// ListenUDP listens on the network address and returns a transport.UDPConn
// which includes support for write and close hooks.
func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) {
if CustomRoutingDisabled() {
return net.ListenUDP(network, laddr)
}
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
if err != nil {
return nil, fmt.Errorf("listen UDP: %w", err)
}
switch c := conn.(type) {
case *net.UDPConn:
// Advanced routing: plain connection
return c, nil
case *PacketConn:
// Legacy routing: wrapped connection for hooks
udpConn, ok := c.PacketConn.(*net.UDPConn)
if !ok {
if err := c.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
return nil, fmt.Errorf("expected UDPConn, got %T", c.PacketConn)
}
return &UDPConn{UDPConn: udpConn, ID: c.ID, seenAddrs: &sync.Map{}}, nil
}
if err := conn.Close(); err != nil {
log.Errorf("failed to close connection: %v", err)
}
return nil, fmt.Errorf("unexpected connection type: %T", conn)
}

View File

@@ -7,14 +7,12 @@ import (
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
type ListenerConfig struct {
*net.ListenConfig
net.ListenConfig
}
// NewListener creates a new ListenerConfig instance.
func NewListener() *ListenerConfig {
listener := &ListenerConfig{
ListenConfig: &net.ListenConfig{},
}
listener := &ListenerConfig{}
listener.init()
return listener

View File

@@ -0,0 +1,7 @@
//go:build !linux && !windows
package net
func (l *ListenerConfig) init() {
// implemented on Linux, Android, and Windows only
}

View File

@@ -0,0 +1,8 @@
package net
func (l *ListenerConfig) init() {
// TODO: this will select a single source interface, but for UDP we can have various source interfaces and IP addresses.
// For now we stick to the one that matches the request IP address, which can be the unspecified IP. In this case
// the interface will be selected that serves the default route.
l.ListenConfig.Control = applyUnicastIFToSocket
}

View File

@@ -0,0 +1,153 @@
//go:build !ios
package net
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/net/hooks"
)
// ListenPacket listens on the network address and returns a PacketConn
// which includes support for write hooks.
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
if CustomRoutingDisabled() || AdvancedRouting() {
return l.ListenConfig.ListenPacket(ctx, network, address)
}
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
if err != nil {
return nil, fmt.Errorf("listen packet: %w", err)
}
connID := hooks.GenerateConnID()
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
}
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
type PacketConn struct {
net.PacketConn
ID hooks.ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
log.Errorf("Failed to call write hooks: %v", err)
}
return c.PacketConn.WriteTo(b, addr)
}
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
func (c *PacketConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.PacketConn)
}
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
type UDPConn struct {
*net.UDPConn
ID hooks.ConnectionID
seenAddrs *sync.Map
}
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
if err := callWriteHooks(c.ID, c.seenAddrs, addr); err != nil {
log.Errorf("Failed to call write hooks: %v", err)
}
return c.UDPConn.WriteTo(b, addr)
}
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
func (c *UDPConn) Close() error {
defer c.seenAddrs.Clear()
return closeConn(c.ID, c.UDPConn)
}
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
func (c *PacketConn) RemoveAddress(addr string) {
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
return
}
ipStr, _, err := net.SplitHostPort(addr)
if err != nil {
log.Errorf("Error splitting IP address and port: %v", err)
return
}
ipAddr, err := netip.ParseAddr(ipStr)
if err != nil {
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
return
}
prefix := netip.PrefixFrom(ipAddr.Unmap(), ipAddr.BitLen())
addressRemoveHooks := hooks.GetAddressRemoveHooks()
if len(addressRemoveHooks) == 0 {
return
}
for _, hook := range addressRemoveHooks {
if err := hook(c.ID, prefix); err != nil {
log.Errorf("Error executing listener address remove hook: %v", err)
}
}
}
// WrapPacketConn wraps an existing net.PacketConn with nbnet hook functionality
func WrapPacketConn(conn net.PacketConn) net.PacketConn {
if AdvancedRouting() {
// hooks not required for advanced routing
return conn
}
return &PacketConn{
PacketConn: conn,
ID: hooks.GenerateConnID(),
seenAddrs: &sync.Map{},
}
}
func callWriteHooks(id hooks.ConnectionID, seenAddrs *sync.Map, addr net.Addr) error {
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); loaded {
return nil
}
writeHooks := hooks.GetWriteHooks()
if len(writeHooks) == 0 {
return nil
}
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return fmt.Errorf("expected *net.UDPAddr for packet connection, got %T", addr)
}
prefix, err := util.GetPrefixFromIP(udpAddr.IP)
if err != nil {
return fmt.Errorf("convert UDP IP %s to prefix: %w", udpAddr.IP, err)
}
log.Debugf("Listener resolved IP for %s: %s", addr, prefix)
var merr *multierror.Error
for _, hook := range writeHooks {
if err := hook(id, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("execute write hook: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -5,8 +5,6 @@ import (
"math/big"
"net"
"net/netip"
"github.com/google/uuid"
)
const (
@@ -44,18 +42,6 @@ func IsDataPlaneMark(fwmark uint32) bool {
return fwmark >= DataPlaneMarkLower && fwmark <= DataPlaneMarkUpper
}
// ConnectionID provides a globally unique identifier for network connections.
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
type ConnectionID string
type AddHookFunc func(connID ConnectionID, IP net.IP) error
type RemoveHookFunc func(connID ConnectionID) error
// GenerateConnID generates a unique identifier for each connection.
func GenerateConnID() ConnectionID {
return ConnectionID(uuid.NewString())
}
func GetLastIPFromNetwork(network netip.Prefix, fromEnd int) (netip.Addr, error) {
var endIP net.IP
addr := network.Addr().AsSlice()

284
client/net/net_windows.go Normal file
View File

@@ -0,0 +1,284 @@
package net
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"syscall"
"time"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
)
const (
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options
IpUnicastIf = 31
Ipv6UnicastIf = 31
// https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ipv6-socket-options
Ipv6V6only = 27
)
// GetBestInterfaceFunc is set at runtime to avoid import cycle
var GetBestInterfaceFunc func(dest netip.Addr, vpnIntf string) (*net.Interface, error)
// nativeToBigEndian converts a uint32 from native byte order to big-endian
func nativeToBigEndian(v uint32) uint32 {
return (v&0xff)<<24 | (v&0xff00)<<8 | (v&0xff0000)>>8 | (v&0xff000000)>>24
}
// parseDestinationAddress parses the destination address from various formats
func parseDestinationAddress(network, address string) (netip.Addr, error) {
if address == "" {
if strings.HasSuffix(network, "6") {
return netip.IPv6Unspecified(), nil
}
return netip.IPv4Unspecified(), nil
}
if addrPort, err := netip.ParseAddrPort(address); err == nil {
return addrPort.Addr(), nil
}
if dest, err := netip.ParseAddr(address); err == nil {
return dest, nil
}
host, _, err := net.SplitHostPort(address)
if err != nil {
// No port, treat whole string as host
host = address
}
if host == "" {
if strings.HasSuffix(network, "6") {
return netip.IPv6Unspecified(), nil
}
return netip.IPv4Unspecified(), nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
if err != nil || len(ips) == 0 {
return netip.Addr{}, fmt.Errorf("resolve destination %s: %w", host, err)
}
dest, ok := netip.AddrFromSlice(ips[0].IP)
if !ok {
return netip.Addr{}, fmt.Errorf("convert IP %v to netip.Addr", ips[0].IP)
}
if ips[0].Zone != "" {
dest = dest.WithZone(ips[0].Zone)
}
return dest, nil
}
func getInterfaceFromZone(zone string) *net.Interface {
if zone == "" {
return nil
}
idx, err := strconv.Atoi(zone)
if err != nil {
log.Debugf("invalid zone format for Windows (expected numeric): %s", zone)
return nil
}
iface, err := net.InterfaceByIndex(idx)
if err != nil {
log.Debugf("failed to get interface by index %d from zone: %v", idx, err)
return nil
}
return iface
}
type interfaceSelection struct {
iface4 *net.Interface
iface6 *net.Interface
}
func selectInterfaceForZone(dest netip.Addr, zone string) *interfaceSelection {
iface := getInterfaceFromZone(zone)
if iface == nil {
return nil
}
if dest.Is6() {
return &interfaceSelection{iface6: iface}
}
return &interfaceSelection{iface4: iface}
}
func selectInterfaceForUnspecified() (*interfaceSelection, error) {
if GetBestInterfaceFunc == nil {
return nil, errors.New("GetBestInterfaceFunc not initialized")
}
var result interfaceSelection
vpnIfaceName := GetVPNInterfaceName()
if iface4, err := GetBestInterfaceFunc(netip.IPv4Unspecified(), vpnIfaceName); err == nil {
result.iface4 = iface4
} else {
log.Debugf("No IPv4 default route found: %v", err)
}
if iface6, err := GetBestInterfaceFunc(netip.IPv6Unspecified(), vpnIfaceName); err == nil {
result.iface6 = iface6
} else {
log.Debugf("No IPv6 default route found: %v", err)
}
if result.iface4 == nil && result.iface6 == nil {
return nil, errors.New("no default routes found")
}
return &result, nil
}
func selectInterface(dest netip.Addr) (*interfaceSelection, error) {
if zone := dest.Zone(); zone != "" {
if selection := selectInterfaceForZone(dest, zone); selection != nil {
return selection, nil
}
}
if dest.IsUnspecified() {
return selectInterfaceForUnspecified()
}
if GetBestInterfaceFunc == nil {
return nil, errors.New("GetBestInterfaceFunc not initialized")
}
iface, err := GetBestInterfaceFunc(dest, GetVPNInterfaceName())
if err != nil {
return nil, fmt.Errorf("find route for %s: %w", dest, err)
}
if dest.Is6() {
return &interfaceSelection{iface6: iface}, nil
}
return &interfaceSelection{iface4: iface}, nil
}
func setIPv4UnicastIF(fd uintptr, iface *net.Interface) error {
ifaceIndexBE := nativeToBigEndian(uint32(iface.Index))
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, IpUnicastIf, int(ifaceIndexBE)); err != nil {
return fmt.Errorf("set IP_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6UnicastIF(fd uintptr, iface *net.Interface) error {
if err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6UnicastIf, iface.Index); err != nil {
return fmt.Errorf("set IPV6_UNICAST_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setUnicastIf(fd uintptr, network string, selection *interfaceSelection, address string) error {
// The Go runtime always passes specific network types to Control (udp4, udp6, tcp4, tcp6, etc.)
// Never generic ones (udp, tcp, ip)
switch {
case strings.HasSuffix(network, "4"):
// IPv4-only socket (udp4, tcp4, ip4)
return setUnicastIfIPv4(fd, network, selection, address)
case strings.HasSuffix(network, "6"):
// IPv6 socket (udp6, tcp6, ip6) - could be dual-stack or IPv6-only
return setUnicastIfIPv6(fd, network, selection, address)
}
// Shouldn't reach here based on Go's documented behavior
return fmt.Errorf("unexpected network type: %s", network)
}
func setUnicastIfIPv4(fd uintptr, network string, selection *interfaceSelection, address string) error {
if selection.iface4 == nil {
return nil
}
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
return err
}
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s", selection.iface4.Index, selection.iface4.Name, network, address)
return nil
}
func setUnicastIfIPv6(fd uintptr, network string, selection *interfaceSelection, address string) error {
isDualStack := checkDualStack(fd)
// For dual-stack sockets, also set the IPv4 option
if isDualStack && selection.iface4 != nil {
if err := setIPv4UnicastIF(fd, selection.iface4); err != nil {
return err
}
log.Debugf("Set IP_UNICAST_IF=%d on %s for %s to %s (dual-stack)", selection.iface4.Index, selection.iface4.Name, network, address)
}
if selection.iface6 == nil {
return nil
}
if err := setIPv6UnicastIF(fd, selection.iface6); err != nil {
return err
}
log.Debugf("Set IPV6_UNICAST_IF=%d on %s for %s to %s", selection.iface6.Index, selection.iface6.Name, network, address)
return nil
}
func checkDualStack(fd uintptr) bool {
var v6Only int
v6OnlyLen := int32(unsafe.Sizeof(v6Only))
err := windows.Getsockopt(windows.Handle(fd), windows.IPPROTO_IPV6, Ipv6V6only, (*byte)(unsafe.Pointer(&v6Only)), &v6OnlyLen)
return err == nil && v6Only == 0
}
// applyUnicastIFToSocket applies IpUnicastIf to a socket based on the destination address
func applyUnicastIFToSocket(network string, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
dest, err := parseDestinationAddress(network, address)
if err != nil {
return err
}
dest = dest.Unmap()
if !dest.IsValid() {
return fmt.Errorf("invalid destination address for %s", address)
}
selection, err := selectInterface(dest)
if err != nil {
return err
}
var controlErr error
err = c.Control(func(fd uintptr) {
controlErr = setUnicastIf(fd, network, selection, address)
})
if err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

View File

@@ -2,7 +2,7 @@
set -eEuo pipefail
: ${NB_ENTRYPOINT_SERVICE_TIMEOUT:="5"}
: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="1"}
: ${NB_ENTRYPOINT_LOGIN_TIMEOUT:="5"}
NETBIRD_BIN="${NETBIRD_BIN:-"netbird"}"
export NB_LOG_FILE="${NB_LOG_FILE:-"console,/var/log/netbird/client.log"}"
service_pids=()
@@ -39,7 +39,7 @@ wait_for_message() {
info "not waiting for log line ${message@Q} due to zero timeout."
elif test -n "${log_file_path}"; then
info "waiting for log line ${message@Q} for ${timeout} seconds..."
grep -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null)
grep -E -q "${message}" <(timeout "${timeout}" tail -F "${log_file_path}" 2>/dev/null)
else
info "log file unsupported, sleeping for ${timeout} seconds..."
sleep "${timeout}"
@@ -81,7 +81,7 @@ wait_for_daemon_startup() {
login_if_needed() {
local timeout="${1}"
if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered'; then
if test -n "${log_file_path}" && wait_for_message "${timeout}" 'peer has been successfully registered|management connection state READY'; then
info "already logged in, skipping 'netbird up'..."
else
info "logging in..."

View File

@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.36.6
// protoc v5.29.3
// protoc v6.32.1
// source: daemon.proto
package proto
@@ -794,6 +794,8 @@ type StatusRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
GetFullPeerStatus bool `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
ShouldRunProbes bool `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
WaitForReady *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -842,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
return false
}
func (x *StatusRequest) GetWaitForReady() bool {
if x != nil && x.WaitForReady != nil {
return *x.WaitForReady
}
return false
}
type StatusResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// status of the server.
@@ -4673,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" +
"\f_profileNameB\v\n" +
"\t_username\"\f\n" +
"\n" +
"UpResponse\"g\n" +
"UpResponse\"\xa1\x01\n" +
"\rStatusRequest\x12,\n" +
"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" +
"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" +
"\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" +
"\r_waitForReady\"\x82\x01\n" +
"\x0eStatusResponse\x12\x16\n" +
"\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
"\n" +
@@ -5231,6 +5242,7 @@ func file_daemon_proto_init() {
}
file_daemon_proto_msgTypes[1].OneofWrappers = []any{}
file_daemon_proto_msgTypes[5].OneofWrappers = []any{}
file_daemon_proto_msgTypes[7].OneofWrappers = []any{}
file_daemon_proto_msgTypes[26].OneofWrappers = []any{
(*PortInfo_Port)(nil),
(*PortInfo_Range_)(nil),

View File

@@ -186,6 +186,8 @@ message UpResponse {}
message StatusRequest{
bool getFullPeerStatus = 1;
bool shouldRunProbes = 2;
// the UI do not using this yet, but CLIs could use it to wait until the status is ready
optional bool waitForReady = 3;
}
message StatusResponse{

View File

@@ -65,6 +65,9 @@ type Server struct {
mutex sync.Mutex
config *profilemanager.Config
proto.UnimplementedDaemonServiceServer
clientRunning bool // protected by mutex
clientRunningChan chan struct{}
clientGiveUpChan chan struct{}
connectClient *internal.ConnectClient
@@ -103,6 +106,11 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
func (s *Server) Start() error {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.clientRunning {
return nil
}
state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil {
@@ -172,8 +180,10 @@ func (s *Server) Start() error {
return nil
}
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return nil
}
@@ -204,12 +214,22 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status,
runningChan chan struct{},
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
defer func() {
s.mutex.Lock()
s.clientRunning = false
s.mutex.Unlock()
}()
if s.config.DisableAutoConnect {
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
log.Debugf("run client connection exited with error: %v", err)
}
log.Tracef("client connection exited")
return
}
backOff := getConnectWithBackoff(ctx)
go func() {
t := time.NewTicker(24 * time.Hour)
for {
@@ -218,89 +238,36 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage
t.Stop()
return
case <-t.C:
if retryStarted {
mgmtState := statusRecorder.GetManagementState()
signalState := statusRecorder.GetSignalState()
if mgmtState.Connected && signalState.Connected {
log.Tracef("resetting status")
retryStarted = false
backOff.Reset()
} else {
log.Tracef("not resetting status: mgmt: %v, signal: %v", mgmtState.Connected, signalState.Connected)
}
}
}
}
}()
runOperation := func() error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
err := s.connectClient.Run(runningChan)
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
return err
}
if config.DisableAutoConnect {
return backoff.Permanent(err)
log.Tracef("client connection exited gracefully, do not need to retry")
return nil
}
if !retryStarted {
retryStarted = true
backOff.Reset()
if err := backoff.Retry(runOperation, backOff); err != nil {
log.Errorf("operation failed: %v", err)
}
log.Tracef("client connection exited")
return fmt.Errorf("client connection exited")
if giveUpChan != nil {
close(giveUpChan)
}
err := backoff.Retry(runOperation, backOff)
if s, ok := gstatus.FromError(err); ok && s.Code() != codes.Canceled {
log.Errorf("received an error when trying to connect: %v", err)
} else {
log.Tracef("retry canceled")
}
}
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
// loginAttempt attempts to login using the provided information. it returns a status in case something fails
@@ -419,7 +386,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
if s.actCancel != nil {
s.actCancel()
}
ctx, cancel := context.WithCancel(s.rootCtx)
ctx, cancel := context.WithCancel(callerCtx)
md, ok := metadata.FromIncomingContext(callerCtx)
if ok {
@@ -429,11 +396,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
s.actCancel = cancel
s.mutex.Unlock()
if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil {
if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil {
log.Warnf(errRestoreResidualState, err)
}
state := internal.CtxGetState(ctx)
state := internal.CtxGetState(s.rootCtx)
defer func() {
status, err := state.Status()
if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) {
@@ -646,6 +613,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
// Up starts engine work in the daemon.
func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
s.mutex.Lock()
if s.clientRunning {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
s.mutex.Unlock()
return nil, err
}
if status == internal.StatusNeedsLogin {
s.actCancel()
}
s.mutex.Unlock()
return s.waitForUp(callerCtx)
}
defer s.mutex.Unlock()
if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
@@ -661,16 +642,16 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
if err != nil {
return nil, err
}
if status != internal.StatusIdle {
return nil, fmt.Errorf("up already in progress: current status %s", status)
}
// it should be nil here, but .
// it should be nil here, but in case it isn't we cancel it.
if s.actCancel != nil {
s.actCancel()
}
ctx, cancel := context.WithCancel(s.rootCtx)
md, ok := metadata.FromIncomingContext(callerCtx)
if ok {
ctx = metadata.NewOutgoingContext(ctx, md)
@@ -713,14 +694,23 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
s.clientRunning = true
s.clientRunningChan = make(chan struct{})
s.clientGiveUpChan = make(chan struct{})
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
return s.waitForUp(callerCtx)
}
// todo: handle potential race conditions
func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
defer cancel()
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
for {
select {
case <-runningChan:
case <-s.clientGiveUpChan:
return nil, fmt.Errorf("client gave up to connect")
case <-s.clientRunningChan:
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil
case <-callerCtx.Done():
@@ -730,7 +720,6 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
}
}
}
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
@@ -1003,12 +992,46 @@ func (s *Server) Status(
ctx context.Context,
msg *proto.StatusRequest,
) (*proto.StatusResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
s.mutex.Lock()
clientRunning := s.clientRunning
s.mutex.Unlock()
if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
state := internal.CtxGetState(s.rootCtx)
status, err := state.Status()
if err != nil {
return nil, err
}
s.mutex.Lock()
defer s.mutex.Unlock()
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-s.clientGiveUpChan:
ticker.Stop()
break loop
case <-s.clientRunningChan:
ticker.Stop()
break loop
case <-ticker.C:
status, err := state.Status()
if err != nil {
continue
}
if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
s.actCancel()
}
continue
case <-ctx.Done():
return nil, ctx.Err()
}
}
}
status, err := internal.CtxGetState(s.rootCtx).Status()
if err != nil {
@@ -1127,6 +1150,134 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
}, nil
}
// AddProfile adds a new profile to the daemon.
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" || msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
}
profiles, err := s.profileManager.ListProfiles(msg.Username)
if err != nil {
log.Errorf("failed to list profiles: %v", err)
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
response := &proto.ListProfilesResponse{
Profiles: make([]*proto.Profile, len(profiles)),
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Name: profile.Name,
IsActive: profile.IsActive,
}
}
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
activeProfile, err := s.profileManager.GetActiveProfileState()
if err != nil {
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
if err := s.connectClient.Run(runningChan); err != nil {
return err
}
return nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}
return false
}
func (s *Server) onSessionExpire() {
if runtime.GOOS != "windows" {
isUIActive := internal.CheckUIApp()
@@ -1138,6 +1289,45 @@ func (s *Server) onSessionExpire() {
}
}
// getConnectWithBackoff returns a backoff with exponential backoff strategy for connection retries
func getConnectWithBackoff(ctx context.Context) backoff.BackOff {
initialInterval := parseEnvDuration(retryInitialIntervalVar, defaultInitialRetryTime)
maxInterval := parseEnvDuration(maxRetryIntervalVar, defaultMaxRetryInterval)
maxElapsedTime := parseEnvDuration(maxRetryTimeVar, defaultMaxRetryTime)
multiplier := defaultRetryMultiplier
if envValue := os.Getenv(retryMultiplierVar); envValue != "" {
// parse the multiplier from the environment variable string value to float64
value, err := strconv.ParseFloat(envValue, 64)
if err != nil {
log.Warnf("unable to parse environment variable %s: %s. using default: %f", retryMultiplierVar, envValue, multiplier)
} else {
multiplier = value
}
}
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: initialInterval,
RandomizationFactor: 1,
Multiplier: multiplier,
MaxInterval: maxInterval,
MaxElapsedTime: maxElapsedTime, // 14 days
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// parseEnvDuration parses the environment variable and returns the duration
func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duration {
if envValue := os.Getenv(envVar); envValue != "" {
if duration, err := time.ParseDuration(envValue); err == nil {
return duration
}
log.Warnf("unable to parse environment variable %s: %s. using default: %s", envVar, envValue, defaultDuration)
}
return defaultDuration
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
@@ -1252,121 +1442,3 @@ func sendTerminalNotification() error {
return wallCmd.Wait()
}
// AddProfile adds a new profile to the daemon.
func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (*proto.AddProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" || msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesRequest) (*proto.ListProfilesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided")
}
profiles, err := s.profileManager.ListProfiles(msg.Username)
if err != nil {
log.Errorf("failed to list profiles: %v", err)
return nil, fmt.Errorf("failed to list profiles: %w", err)
}
response := &proto.ListProfilesResponse{
Profiles: make([]*proto.Profile, len(profiles)),
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Name: profile.Name,
IsActive: profile.IsActive,
}
}
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
activeProfile, err := s.profileManager.GetActiveProfileState()
if err != nil {
log.Errorf("failed to get active profile state: %v", err)
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
}, nil
}
// GetFeatures returns the features supported by the daemon.
func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) (*proto.GetFeaturesResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
features := &proto.GetFeaturesResponse{
DisableProfiles: s.checkProfilesDisabled(),
DisableUpdateSettings: s.checkUpdateSettingsDisabled(),
}
return features, nil
}
func (s *Server) checkProfilesDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.profilesDisabled {
return true
}
return false
}
func (s *Server) checkUpdateSettingsDisabled() bool {
// Check if the environment variable is set to disable profiles
if s.updateSettingsDisabled {
return true
}
return false
}

View File

@@ -18,6 +18,7 @@ import (
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -104,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
@@ -133,8 +134,12 @@ func TestServer_Up(t *testing.T) {
profName := "default"
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
require.NoError(t, err)
ic := profilemanager.ConfigInput{
ConfigPath: filepath.Join(tempDir, profName+".json"),
ManagementURL: u.String(),
}
_, err = profilemanager.UpdateOrCreateConfig(ic)
@@ -152,16 +157,9 @@ func TestServer_Up(t *testing.T) {
}
s := New(ctx, "console", "", false, false)
err = s.Start()
require.NoError(t, err)
u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
require.NoError(t, err)
s.config = &profilemanager.Config{
ManagementURL: u,
}
upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
defer cancel()
@@ -170,6 +168,7 @@ func TestServer_Up(t *testing.T) {
Username: &currUser.Username,
}
_, err = s.Up(upCtx, upReq)
log.Errorf("error from Up: %v", err)
assert.Contains(t, err.Error(), "context deadline exceeded")
}

View File

@@ -563,6 +563,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
return
}
go func() {
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
@@ -583,7 +584,7 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
return
}
}
}()
}
},
OnCancel: func() {

View File

@@ -20,9 +20,9 @@ import (
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
nbgrpc "github.com/netbirdio/netbird/client/grpc"
"github.com/netbirdio/netbird/flow/proto"
"github.com/netbirdio/netbird/util/embeddedroots"
nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
type GRPCClient struct {

2
go.mod
View File

@@ -261,6 +261,6 @@ replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-2
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107
replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944

4
go.sum
View File

@@ -501,8 +501,8 @@ github.com/neelance/astrewrite v0.0.0-20160511093645-99348263ae86/go.mod h1:kHJE
github.com/neelance/sourcemap v0.0.0-20200213170602-2833bce08e4c/go.mod h1:Qr6/a/Q4r9LP1IltGz7tA7iOK1WonHEYhu1HRBA7ZiM=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw=
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=

View File

@@ -328,6 +328,45 @@ delete_auto_service_user() {
echo "$PARSED_RESPONSE"
}
delete_default_zitadel_admin() {
INSTANCE_URL=$1
PAT=$2
# Search for the default zitadel-admin user
RESPONSE=$(
curl -sS -X POST "$INSTANCE_URL/management/v1/users/_search" \
-H "Authorization: Bearer $PAT" \
-H "Content-Type: application/json" \
-d '{
"queries": [
{
"userNameQuery": {
"userName": "zitadel-admin@",
"method": "TEXT_QUERY_METHOD_STARTS_WITH"
}
}
]
}'
)
DEFAULT_ADMIN_ID=$(echo "$RESPONSE" | jq -r '.result[0].id // empty')
if [ -n "$DEFAULT_ADMIN_ID" ] && [ "$DEFAULT_ADMIN_ID" != "null" ]; then
echo "Found default zitadel-admin user with ID: $DEFAULT_ADMIN_ID"
RESPONSE=$(
curl -sS -X DELETE "$INSTANCE_URL/management/v1/users/$DEFAULT_ADMIN_ID" \
-H "Authorization: Bearer $PAT" \
-H "Content-Type: application/json" \
)
PARSED_RESPONSE=$(echo "$RESPONSE" | jq -r '.details.changeDate // "deleted"')
handle_zitadel_request_response "$PARSED_RESPONSE" "delete_default_zitadel_admin" "$RESPONSE"
else
echo "Default zitadel-admin user not found: $RESPONSE"
fi
}
init_zitadel() {
echo -e "\nInitializing Zitadel with NetBird's applications\n"
INSTANCE_URL="$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN"
@@ -346,6 +385,9 @@ init_zitadel() {
echo -n "Waiting for Zitadel to become ready "
wait_api "$INSTANCE_URL" "$PAT"
echo "Deleting default zitadel-admin user..."
delete_default_zitadel_admin "$INSTANCE_URL" "$PAT"
# create the zitadel project
echo "Creating new zitadel project"
PROJECT_ID=$(create_new_project "$INSTANCE_URL" "$PAT")

View File

@@ -20,29 +20,9 @@ import (
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
CustomZones sync.Map
NameServerGroups sync.Map
}
// GetCustomZone retrieves a cached custom zone
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
if c == nil {
return nil, false
}
if value, ok := c.CustomZones.Load(key); ok {
return value.(*proto.CustomZone), true
}
return nil, false
}
// SetCustomZone stores a custom zone in the cache
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
if c == nil {
return
}
c.CustomZones.Store(key, value)
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
@@ -212,15 +192,9 @@ func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSC
}
for _, zone := range update.CustomZones {
cacheKey := zone.Domain
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
} else {
protoZone := convertToProtoCustomZone(zone)
cache.SetCustomZone(cacheKey, protoZone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID

Some files were not shown because too many files have changed in this diff Show More