Compare commits

..

41 Commits

Author SHA1 Message Date
Zoltán Papp
6efc1a61fe Init static info in proper place 2025-09-04 16:44:58 +02:00
Zoltán Papp
4a45e40578 Cache os related information 2025-09-04 15:50:25 +02:00
Zoltán Papp
85ad236f6d Merge branch 'fix/missed-offers' into merged-fixes 2025-09-04 13:54:30 +02:00
Zoltán Papp
e54ab4b5e1 Merge branch 'debug-slow-conn' into merged-fixes 2025-09-04 13:54:13 +02:00
Zoltán Papp
23977c6409 Add debug ticker 2025-09-04 13:53:22 +02:00
Zoltán Papp
a73cb99045 Add async WireGuard handshake 2025-09-04 13:43:10 +02:00
Zoltán Papp
ccf40fefaf Fix proxy tests 2025-09-04 13:15:38 +02:00
Zoltán Papp
b2548a4037 Add verbose sys info gathering information 2025-09-04 12:11:57 +02:00
Zoltán Papp
180b52ba26 Merge branch 'main' into fix/pkg-loss 2025-09-03 16:27:17 +02:00
Zoltán Papp
f86a7f745e Fix logging library 2025-09-03 14:08:53 +02:00
Zoltán Papp
fd13247d66 Explicit print out the first success handshake time 2025-09-03 14:06:28 +02:00
Zoltán Papp
1d83fccd9c Improve logging for management service connection and system info gathering 2025-09-03 13:28:04 +02:00
Zoltán Papp
7d3c972653 Do not block Offer processing from relay worker
- do not miss ICE offers when relay worker busy
- close p2p connection before recreate agent
2025-09-02 22:17:13 +02:00
Zoltán Papp
990aa8f7cd Remove wg initiator logic 2025-02-24 10:46:22 +01:00
Zoltán Papp
2e6a44af1f Merge branch 'main' into fix/pkg-loss 2025-02-24 10:44:32 +01:00
Zoltán Papp
559d347588 Revert async WireGuard handshake 2025-02-21 16:19:28 +01:00
Zoltán Papp
06d71257b4 Build rawsocket code on linux only. 2025-02-21 16:00:31 +01:00
Zoltán Papp
62d10496ee Merge branch 'main' into fix/pkg-loss 2025-02-21 15:52:51 +01:00
Zoltán Papp
651e88d611 Eliminate code duplication 2025-02-21 14:59:42 +01:00
Zoltán Papp
d496d21693 Apply same pausedCond logic on all implementation 2025-02-21 14:55:56 +01:00
Zoltan Papp
648b4cdf72 Update client/iface/wgproxy/udp/proxy.go
Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-02-21 14:50:29 +01:00
Zoltán Papp
f0020ad4ce Fallback to package loss solution if the raw socket does not work. 2025-02-18 13:58:20 +01:00
Zoltán Papp
1eacff250e Remove WireGuard kernel code from FreeBSD 2025-02-18 13:47:42 +01:00
Zoltán Papp
1f83ba4563 Ignore err in tests 2025-02-17 22:06:04 +01:00
Zoltán Papp
3d80a25b4d Fix possible blocker if the bind will be closed earlier then proxy 2025-02-17 22:03:15 +01:00
Zoltán Papp
1963644c99 Add close test for all implementation 2025-02-17 21:47:34 +01:00
Zoltán Papp
360c7134f7 Add unit test 2025-02-17 21:26:37 +01:00
Zoltán Papp
775b4feb7e Fix close operation 2025-02-17 20:17:47 +01:00
Zoltán Papp
aca443bdec Build UDP proxy on Linux only 2025-02-17 19:26:40 +01:00
Zoltán Papp
335866ac60 Close unused rawsocket 2025-02-17 15:56:48 +01:00
Zoltán Papp
082452eb5f Add info log line 2025-02-17 15:50:49 +01:00
Zoltan Papp
b17c1d96a5 [client] Improve WireGuard handshake success rate (#3092)
The controller peer sends WireGuard handshake requests only
2025-02-17 15:50:15 +01:00
Zoltán Papp
2d5b5f59c2 Remove log line 2025-02-17 15:49:04 +01:00
Zoltán Papp
d5042f688f Fix interface changes 2025-02-17 15:47:20 +01:00
Zoltán Papp
4db73a13d7 Implement redirect logic in UDP proxy 2025-02-17 15:47:20 +01:00
Zoltán Papp
06a17f0eee Implement redirect to in eBPF proxy 2025-02-17 15:47:20 +01:00
Zoltán Papp
1f088b7e69 Extend the proxy interface with RedirectTo function and implement it in Bind proxy 2025-02-17 15:47:20 +01:00
Zoltan Papp
ffe74365a8 Code cleaning 2025-02-17 15:47:20 +01:00
Zoltán Papp
6a0f6efc18 Always stop timer 2025-02-17 15:47:20 +01:00
Zoltán Papp
bfa6df13c5 The non handshake initiator peer start the handshake after timeout 2025-02-17 15:47:20 +01:00
Zoltán Papp
e9b3b6210d Improve WireGuard handshake success rate
The controller peer sends WireGuard
handshake requests only
2025-02-17 15:47:20 +01:00
56 changed files with 1175 additions and 1113 deletions

View File

@@ -4,7 +4,6 @@ package android
import (
"context"
"os"
"slices"
"sync"
@@ -84,8 +83,7 @@ func NewClient(cfgFile string, androidSDKVersion int, deviceName string, uiVersi
}
// Run start the internal client. It is a blocker function
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -120,8 +118,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
// In this case make no sense handle registration steps.
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
exportEnvList(envList)
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener) error {
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
})
@@ -252,14 +249,3 @@ func (c *Client) SetConnectionListener(listener ConnectionListener) {
func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
func exportEnvList(list *EnvList) {
if list == nil {
return
}
for k, v := range list.AllItems() {
if err := os.Setenv(k, v); err != nil {
log.Errorf("could not set env variable %s: %v", k, err)
}
}
}

View File

@@ -1,32 +0,0 @@
package android
import "github.com/netbirdio/netbird/client/internal/peer"
var (
// EnvKeyNBForceRelay Exported for Android java client
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
)
// EnvList wraps a Go map for export to Java
type EnvList struct {
data map[string]string
}
// NewEnvList creates a new EnvList
func NewEnvList() *EnvList {
return &EnvList{data: make(map[string]string)}
}
// Put adds a key-value pair
func (el *EnvList) Put(key, value string) {
el.data[key] = value
}
// Get retrieves a value by key
func (el *EnvList) Get(key string) string {
return el.data[key]
}
func (el *EnvList) AllItems() map[string]string {
return el.data
}

View File

@@ -388,12 +388,12 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
}
func init() {
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 10, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
debugBundleCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
debugBundleCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
debugBundleCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 10, "Number of rotated log files to include in debug bundle")
forCmd.Flags().Uint32VarP(&logFileCount, "log-file-count", "C", 1, "Number of rotated log files to include in debug bundle")
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")

View File

@@ -227,7 +227,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
}
// update host's static platform and system information
system.UpdateStaticInfoAsync()
system.UpdateStaticInfo()
configFilePath, err := activeProf.FilePath()
if err != nil {

View File

@@ -27,7 +27,7 @@ func (p *program) Start(svc service.Service) error {
log.Info("starting NetBird service") //nolint
// Collect static system and platform information
system.UpdateStaticInfoAsync()
system.UpdateStaticInfo()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer()

View File

@@ -9,26 +9,29 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server"
mgmt "github.com/netbirdio/netbird/management/server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
sig "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
func startTestingServices(t *testing.T) string {
@@ -87,20 +90,15 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
return nil, nil
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersmanager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
settingsMockManager.EXPECT().

View File

@@ -1,5 +1,17 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
import (
"net"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type Endpoint = wgConn.StdNetEndpoint
func EndpointToUDPAddr(e Endpoint) *net.UDPAddr {
return &net.UDPAddr{
IP: e.Addr().AsSlice(),
Port: int(e.Port()),
Zone: e.Addr().Zone(),
}
}

View File

@@ -1,6 +1,7 @@
package bind
import (
"context"
"encoding/binary"
"fmt"
"net"
@@ -41,7 +42,7 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
recvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
@@ -64,7 +65,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
recvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
@@ -154,6 +155,14 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
return nil
}
func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) {
select {
case <-ctx.Done():
return
case b.recvChan <- msg:
}
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
@@ -270,7 +279,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
case msg, ok := <-c.recvChan:
if !ok {
return 0, net.ErrClosed
}

View File

@@ -0,0 +1,40 @@
//go:build freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || freebsd
//go:build linux && !android
package iface

View File

@@ -16,28 +16,37 @@ import (
"github.com/netbirdio/netbird/client/iface/wgproxy/listener"
)
type IceBind interface {
SetEndpoint(fakeIP netip.Addr, conn net.Conn)
RemoveEndpoint(fakeIP netip.Addr)
Recv(ctx context.Context, msg bind.RecvMessage)
MTU() uint16
}
type ProxyBind struct {
Bind *bind.ICEBind
bind IceBind
fakeNetIP *netip.AddrPort
wgBindEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
// wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address
wgRelayedEndpoint *bind.Endpoint
wgCurrentUsed *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
func NewProxyBind(bind IceBind) *ProxyBind {
p := &ProxyBind{
Bind: bind,
bind: bind,
closeListener: listener.NewCloseListener(),
pausedCond: sync.NewCond(&sync.Mutex{}),
}
return p
@@ -46,25 +55,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind {
// AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
//
// Parameters:
// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages
// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address
// - remoteConn: The established TURN connection to the remote peer
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
fakeNetIP, err := fakeAddress(nbAddr)
if err != nil {
return err
}
p.fakeNetIP = fakeNetIP
p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return nil
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint)
}
func (p *ProxyBind) SetDisconnectListener(disconnected func()) {
@@ -76,17 +85,22 @@ func (p *ProxyBind) Work() {
return
}
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn)
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = false
p.pausedMu.Unlock()
p.wgCurrentUsed = p.wgRelayedEndpoint
// Start the proxy only once
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
p.pausedCond.L.Unlock()
// todo: review to should be inside the lock scope
p.pausedCond.Signal()
}
func (p *ProxyBind) Pause() {
@@ -94,9 +108,25 @@ func (p *ProxyBind) Pause() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = true
p.pausedMu.Unlock()
p.pausedCond.L.Unlock()
}
func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgCurrentUsed = addrToEndpoint(endpoint)
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}
func (p *ProxyBind) CloseConn() error {
@@ -107,6 +137,10 @@ func (p *ProxyBind) CloseConn() error {
}
func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil
}
p.closeMu.Lock()
defer p.closeMu.Unlock()
@@ -120,7 +154,12 @@ func (p *ProxyBind) close() error {
p.cancel()
p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
@@ -136,7 +175,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}()
for {
buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead)
buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead)
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
@@ -147,18 +186,17 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
}
msg := bind.RecvMessage{
Endpoint: p.wgBindEndpoint,
Endpoint: p.wgCurrentUsed,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
p.bind.Recv(ctx, msg)
p.pausedCond.L.Unlock()
}
}

View File

@@ -6,9 +6,7 @@ import (
"context"
"fmt"
"net"
"os"
"sync"
"syscall"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
@@ -18,6 +16,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
@@ -27,6 +26,10 @@ const (
loopbackAddr = "127.0.0.1"
)
var (
localHostNetIP = net.ParseIP("127.0.0.1")
)
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
localWGListenPort int
@@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
return err
}
p.rawConn, err = p.prepareSenderRawSocket()
p.rawConn, err = rawsocket.PrepareSenderRawSocket()
if err != nil {
return err
}
@@ -214,57 +217,17 @@ generatePort:
return p.lastUsedPort, nil
}
func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}
func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
localhost := net.ParseIP("127.0.0.1")
func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error {
payload := gopacket.Payload(data)
ipH := &layers.IPv4{
DstIP: localhost,
SrcIP: localhost,
DstIP: localHostNetIP,
SrcIP: endpointAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(port),
SrcPort: layers.UDPPort(endpointAddr.Port),
DstPort: layers.UDPPort(p.localWGListenPort),
}
@@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error {
if err != nil {
return fmt.Errorf("serialize layers: %w", err)
}
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil {
if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil {
return fmt.Errorf("write to raw conn: %w", err)
}
return nil

View File

@@ -18,41 +18,42 @@ import (
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
wgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
wgRelayedEndpointAddr *net.UDPAddr
wgEndpointCurrentUsedAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
}
func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper {
func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper {
return &ProxyWrapper{
WgeBPFProxy: WgeBPFProxy,
wgeBPFProxy: proxy,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(),
}
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr
p.wgRelayedEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
return p.wgRelayedEndpointAddr
}
func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) {
@@ -64,14 +65,19 @@ func (p *ProxyWrapper) Work() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = false
p.pausedMu.Unlock()
p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
p.pausedCond.L.Unlock()
// todo: review to should be inside the lock scope
p.pausedCond.Signal()
}
func (p *ProxyWrapper) Pause() {
@@ -80,45 +86,59 @@ func (p *ProxyWrapper) Pause() {
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = true
p.pausedMu.Unlock()
p.pausedCond.L.Unlock()
}
func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
p.paused = false
p.wgEndpointCurrentUsedAddr = endpoint
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
p.cancel()
e.closeListener.SetCloseListener(nil)
p.closeListener.SetCloseListener(nil)
if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("close remote conn: %w", err)
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port))
buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead)
buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr)
p.pausedCond.L.Unlock()
if err != nil {
if ctx.Err() != nil {
@@ -137,7 +157,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err
}
p.closeListener.Notify()
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err)
}
return 0, err
}

View File

@@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy {
}
return ebpf.NewProxyWrapper(w.ebpfProxy)
}
func (w *KernelFactory) Free() error {

View File

@@ -1,31 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
mtu uint16
}
func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
mtu: mtu,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -11,6 +11,12 @@ type Proxy interface {
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
/*
RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused
and rewrite the src address to the endpoint address.
With this logic can avoid the package loss from relayed connections.
*/
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error
SetDisconnectListener(disconnected func())
}

View File

@@ -3,54 +3,82 @@
package wgproxy
import (
"context"
"os"
"testing"
"fmt"
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") != "true" {
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
func seedProxies() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err)
}
pEbpf := proxyInstance{
name: "ebpf kernel proxy",
proxy: ebpf.NewProxyWrapper(ebpfProxy),
wgPort: 51831,
closeFn: ebpfProxy.Free,
}
pl = append(pl, pEbpf)
pUDP := proxyInstance{
name: "udp kernel proxy",
proxy: udp.NewWGUDPProxy(51832, 1280),
wgPort: 51832,
closeFn: func() error { return nil },
}
pl = append(pl, pUDP)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -0,0 +1,39 @@
//go:build !linux
package wgproxy
import (
"net"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/wgaddr"
bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
func seedProxies() ([]proxyInstance, error) {
// todo extend with Bind proxy
pl := make([]proxyInstance, 0)
return pl, nil
}
func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
pl := make([]proxyInstance, 0)
wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32")
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,
}
pBind := proxyInstance{
name: "bind proxy",
proxy: bindproxy.NewProxyBind(iceBind),
endpointAddr: endpointAddress,
closeFn: func() error { return nil },
}
pl = append(pl, pBind)
return pl, nil
}

View File

@@ -1,5 +1,3 @@
//go:build linux
package wgproxy
import (
@@ -7,12 +5,9 @@ import (
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util"
)
@@ -22,6 +17,14 @@ func TestMain(m *testing.M) {
os.Exit(code)
}
type proxyInstance struct {
name string
proxy Proxy
wgPort int
endpointAddr *net.UDPAddr
closeFn func() error
}
type mocConn struct {
closeChan chan struct{}
closed bool
@@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error {
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830, 1280),
},
tests, err := seedProxyForProxyCloseByRemoteConn()
if err != nil {
t.Fatalf("error: %v", err)
}
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy)
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = relayedConn.Close()
}()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892")
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
err := tt.proxy.AddTurnConn(ctx, addr, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
@@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
})
}
}
// TestProxyRedirect todo extend the proxies with Bind proxy
func TestProxyRedirect(t *testing.T) {
tests, err := seedProxies()
if err != nil {
t.Fatalf("error: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr)
if err := tt.closeFn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}
func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) {
t.Helper()
msgHelloFromRelay := []byte("hello from relay")
msgRedirected := [][]byte{
[]byte("hello 1. to p2p"),
[]byte("hello 2. to p2p"),
[]byte("hello 3. to p2p"),
}
dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: wgPort})
if err != nil {
t.Fatalf("failed to listen on udp port: %s", err)
}
relayedServer, _ := net.ListenUDP("udp",
&net.UDPAddr{
IP: net.IPv4(127, 0, 0, 1),
Port: 1234,
},
)
relayedConn, _ := net.Dial("udp", "127.0.0.1:1234")
defer func() {
_ = dummyWgListener.Close()
_ = relayedConn.Close()
_ = relayedServer.Close()
}()
if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil {
t.Errorf("error: %v", err)
}
defer func() {
if err := proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
}()
proxy.Work()
if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil {
t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err)
}
n, err := dummyWgListener.Read(make([]byte, 1024))
if err != nil {
t.Errorf("error: %v", err)
}
if n != len(msgHelloFromRelay) {
t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n)
}
p2pEndpointAddr := &net.UDPAddr{
IP: net.IPv4(192, 168, 0, 56),
Port: 1234,
}
proxy.RedirectAs(p2pEndpointAddr)
for _, msg := range msgRedirected {
if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil {
t.Errorf("error: %v", err)
}
}
for i := 0; i < len(msgRedirected); i++ {
buf := make([]byte, 1024)
n, rAddr, err := dummyWgListener.ReadFrom(buf)
if err != nil {
t.Errorf("error: %v", err)
}
if rAddr.String() != p2pEndpointAddr.String() {
t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String())
}
if string(buf[:n]) != string(msgRedirected[i]) {
t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n]))
}
}
}

View File

@@ -0,0 +1,50 @@
//go:build linux && !android
package rawsocket
import (
"fmt"
"net"
"os"
"syscall"
nbnet "github.com/netbirdio/netbird/util/net"
)
func PrepareSenderRawSocket() (net.PacketConn, error) {
// Create a raw socket.
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW)
if err != nil {
return nil, fmt.Errorf("creating raw socket failed: %w", err)
}
// Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet.
err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1)
if err != nil {
return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err)
}
// Bind the socket to the "lo" interface.
err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo")
if err != nil {
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
}
// Set the fwmark on the socket.
err = nbnet.SetSocketOpt(fd)
if err != nil {
return nil, fmt.Errorf("setting fwmark failed: %w", err)
}
// Convert the file descriptor to a PacketConn.
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
if file == nil {
return nil, fmt.Errorf("converting fd to file failed")
}
packetConn, err := net.FilePacketConn(file)
if err != nil {
return nil, fmt.Errorf("converting file to packet conn failed: %w", err)
}
return packetConn, nil
}

View File

@@ -1,3 +1,5 @@
//go:build linux && !android
package udp
import (
@@ -21,16 +23,18 @@ type WGUDPProxy struct {
localWGListenPort int
mtu uint16
remoteConn net.Conn
localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
remoteConn net.Conn
localConn net.Conn
srcFakerConn *SrcFaker
sendPkg func(data []byte) (int, error)
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
paused bool
pausedCond *sync.Cond
isStarted bool
closeListener *listener.CloseListener
}
@@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy {
p := &WGUDPProxy{
localWGListenPort: wgPort,
mtu: mtu,
pausedCond: sync.NewCond(&sync.Mutex{}),
closeListener: listener.NewCloseListener(),
}
return p
@@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem
p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn
p.sendPkg = p.localConn.Write
p.remoteConn = remoteConn
return err
@@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = false
p.pausedMu.Unlock()
p.sendPkg = p.localConn.Write
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
}
// Pause pauses the proxy from receiving data from the remote peer
@@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() {
return
}
p.pausedMu.Lock()
p.pausedCond.L.Lock()
p.paused = true
p.pausedMu.Unlock()
p.pausedCond.L.Unlock()
}
// RedirectAs start to use the fake sourced raw socket as package sender
func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Lock()
defer func() {
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
}()
p.paused = false
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
log.Errorf("failed to close src faker conn: %s", err)
}
p.srcFakerConn = nil
}
srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint)
if err != nil {
log.Errorf("failed to create src faker conn: %s", err)
// fallback to continue without redirecting
p.paused = true
return
}
p.srcFakerConn = srcFakerConn
p.sendPkg = p.srcFakerConn.SendPkg
}
// CloseConn close the localConn
@@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error {
}
func (p *WGUDPProxy) close() error {
var result *multierror.Error
p.closeMu.Lock()
defer p.closeMu.Unlock()
@@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error {
p.cancel()
var result *multierror.Error
p.pausedCond.L.Lock()
p.paused = false
p.pausedCond.L.Unlock()
p.pausedCond.Signal()
if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}
@@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error {
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
if p.srcFakerConn != nil {
if err := p.srcFakerConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err))
}
}
return cerrors.FormatErrorOrNil(result)
}
@@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
p.pausedCond.L.Lock()
for p.paused {
p.pausedCond.Wait()
}
_, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
_, err = p.sendPkg(buf[:n])
p.pausedCond.L.Unlock()
if err != nil {
if ctx.Err() != nil {

View File

@@ -0,0 +1,101 @@
//go:build linux && !android
package udp
import (
"fmt"
"net"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket"
)
var (
serializeOpts = gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
localHostNetIPAddr = &net.IPAddr{
IP: net.ParseIP("127.0.0.1"),
}
)
type SrcFaker struct {
srcAddr *net.UDPAddr
rawSocket net.PacketConn
ipH gopacket.SerializableLayer
udpH gopacket.SerializableLayer
layerBuffer gopacket.SerializeBuffer
}
func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) {
rawSocket, err := rawsocket.PrepareSenderRawSocket()
if err != nil {
return nil, err
}
ipH, udpH, err := prepareHeaders(dstPort, srcAddr)
if err != nil {
return nil, err
}
f := &SrcFaker{
srcAddr: srcAddr,
rawSocket: rawSocket,
ipH: ipH,
udpH: udpH,
layerBuffer: gopacket.NewSerializeBuffer(),
}
return f, nil
}
func (f *SrcFaker) Close() error {
return f.rawSocket.Close()
}
func (f *SrcFaker) SendPkg(data []byte) (int, error) {
defer func() {
if err := f.layerBuffer.Clear(); err != nil {
log.Errorf("failed to clear layer buffer: %s", err)
}
}()
payload := gopacket.Payload(data)
err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload)
if err != nil {
return 0, fmt.Errorf("serialize layers: %w", err)
}
n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr)
if err != nil {
return 0, fmt.Errorf("write to raw conn: %w", err)
}
return n, nil
}
func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) {
ipH := &layers.IPv4{
DstIP: net.ParseIP("127.0.0.1"),
SrcIP: srcAddr.IP,
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolUDP,
}
udpH := &layers.UDP{
SrcPort: layers.UDPPort(srcAddr.Port),
DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port
}
err := udpH.SetNetworkLayerForChecksum(ipH)
if err != nil {
return nil, nil, fmt.Errorf("set network layer for checksum: %w", err)
}
return ipH, udpH, nil
}

View File

@@ -315,10 +315,6 @@ func (g *BundleGenerator) createArchive() error {
return fmt.Errorf("add sync response: %w", err)
}
if err := g.addDNSConfig(); err != nil {
log.Errorf("failed to add DNS config to debug bundle: %v", err)
}
if err := g.addStateFile(); err != nil {
log.Errorf("failed to add state file to debug bundle: %v", err)
}
@@ -345,50 +341,6 @@ func (g *BundleGenerator) createArchive() error {
return nil
}
// addDNSConfig writes a dns_config.json snapshot with routed domains and NS group status
func (g *BundleGenerator) addDNSConfig() error {
type nsGroup struct {
ID string `json:"id"`
Servers []string `json:"servers"`
Domains []string `json:"domains"`
Enabled bool `json:"enabled"`
Error string `json:"error,omitempty"`
}
type dnsConfig struct {
Groups []nsGroup `json:"name_server_groups"`
}
if g.statusRecorder == nil {
return nil
}
states := g.statusRecorder.GetDNSStates()
cfg := dnsConfig{Groups: make([]nsGroup, 0, len(states))}
for _, st := range states {
var servers []string
for _, ap := range st.Servers {
servers = append(servers, ap.String())
}
var errStr string
if st.Error != nil {
errStr = st.Error.Error()
}
cfg.Groups = append(cfg.Groups, nsGroup{
ID: st.ID,
Servers: servers,
Domains: st.Domains,
Enabled: st.Enabled,
Error: errStr,
})
}
bs, err := json.MarshalIndent(cfg, "", " ")
if err != nil {
return fmt.Errorf("marshal dns config: %w", err)
}
return g.addFileToZip(bytes.NewReader(bs), "dns_config.json")
}
func (g *BundleGenerator) addSystemInfo() {
if err := g.addRoutes(); err != nil {
log.Errorf("failed to add routes to debug bundle: %v", err)

View File

@@ -46,18 +46,6 @@ type DNSForwarder struct {
fwdEntries []*ForwarderEntry
firewall firewaller
resolver resolver
// failure rate tracking for routed domains
failureMu sync.Mutex
failureCounts map[string]int
failureWindow time.Duration
lastLogPerHost map[string]time.Time
// per-domain rolling stats and windows
statsMu sync.Mutex
stats map[string]*domainStats
winSize time.Duration
slowT time.Duration
}
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
@@ -68,25 +56,9 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
firewall: firewall,
statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
failureCounts: make(map[string]int),
failureWindow: 10 * time.Second,
lastLogPerHost: make(map[string]time.Time),
stats: make(map[string]*domainStats),
winSize: 10 * time.Second,
slowT: 300 * time.Millisecond,
}
}
type domainStats struct {
total int
success int
timeouts int
notfound int
failures int // other failures (incl. SERVFAIL-like)
slow int
lastLog time.Time
}
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
@@ -191,19 +163,12 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
start := time.Now()
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
elapsed := time.Since(start)
if err != nil {
f.handleDNSError(ctx, w, question, resp, domain, err)
// record error stats for routed domains
f.recordErrorStats(strings.TrimSuffix(domain, "."), err)
return nil
}
// record success timing
f.recordSuccessStats(strings.TrimSuffix(domain, "."), elapsed)
f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips)
@@ -341,91 +306,6 @@ func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter,
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
// Track failure rate for routed domains only
if resID, _ := f.getMatchingEntries(strings.TrimSuffix(domain, ".")); resID != "" {
f.recordDomainFailure(strings.TrimSuffix(domain, "."))
}
}
// recordErrorStats updates per-domain counters and emits rate-limited logs
func (f *DNSForwarder) recordErrorStats(domain string, err error) {
domain = strings.ToLower(domain)
f.statsMu.Lock()
s := f.ensureStats(domain)
s.total++
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) {
if dnsErr.IsNotFound {
s.notfound++
} else if dnsErr.Timeout() {
s.timeouts++
} else {
s.failures++
}
} else {
s.failures++
}
f.maybeLogDomainStats(domain, s)
f.statsMu.Unlock()
}
// recordSuccessStats updates per-domain latency stats and slow counters, logs if needed (rate-limited)
func (f *DNSForwarder) recordSuccessStats(domain string, elapsed time.Duration) {
domain = strings.ToLower(domain)
f.statsMu.Lock()
s := f.ensureStats(domain)
s.total++
s.success++
if elapsed >= f.slowT {
s.slow++
}
f.maybeLogDomainStats(domain, s)
f.statsMu.Unlock()
}
func (f *DNSForwarder) ensureStats(domain string) *domainStats {
if ds, ok := f.stats[domain]; ok {
return ds
}
ds := &domainStats{}
f.stats[domain] = ds
return ds
}
// maybeLogDomainStats logs a compact summary per routed domain at most once per window
func (f *DNSForwarder) maybeLogDomainStats(domain string, s *domainStats) {
now := time.Now()
if !s.lastLog.IsZero() && now.Sub(s.lastLog) < f.winSize {
return
}
// check if routed (avoid logging for non-routed domains)
if resID, _ := f.getMatchingEntries(domain); resID == "" {
return
}
// only log if something noteworthy happened in the window
noteworthy := s.timeouts > 0 || s.notfound > 0 || s.failures > 0 || s.slow > 0
if !noteworthy {
s.lastLog = now
return
}
// warn on persistent problems, info otherwise
levelWarn := s.timeouts >= 3 || s.failures >= 3
if levelWarn {
log.Warnf("[d] DNS stats: domain=%s total=%d ok=%d timeout=%d nxdomain=%d fail=%d slow=%d(>=%s)",
domain, s.total, s.success, s.timeouts, s.notfound, s.failures, s.slow, f.slowT)
} else {
log.Infof("[d] DNS stats: domain=%s total=%d ok=%d timeout=%d nxdomain=%d fail=%d slow=%d(>=%s)",
domain, s.total, s.success, s.timeouts, s.notfound, s.failures, s.slow, f.slowT)
}
// reset counters for next window
*s = domainStats{lastLog: now}
}
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
@@ -461,27 +341,6 @@ func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []neti
}
}
// recordDomainFailure increments failure count for the domain and logs at info/warn with throttling.
func (f *DNSForwarder) recordDomainFailure(domain string) {
domain = strings.ToLower(domain)
f.failureMu.Lock()
defer f.failureMu.Unlock()
f.failureCounts[domain]++
count := f.failureCounts[domain]
now := time.Now()
last, ok := f.lastLogPerHost[domain]
if ok && now.Sub(last) < f.failureWindow {
return
}
f.lastLogPerHost[domain] = now
log.Warnf("[d] DNS failures observed for routed domain: domain=%s failures=%d/%s", domain, count, f.failureWindow)
}
// getMatchingEntries retrieves the resource IDs for a given domain.
// It returns the most specific match and all matching resource IDs.
func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*ForwarderEntry) {

View File

@@ -19,13 +19,17 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -41,12 +45,9 @@ import (
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -1554,11 +1555,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -1575,6 +1572,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)

View File

@@ -6,6 +6,7 @@ import (
"math/rand"
"net"
"net/netip"
"os"
"runtime"
"sync"
"time"
@@ -117,6 +118,8 @@ type Conn struct {
// debug purpose
dumpState *stateDump
endpointUpdater *endpointUpdater
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -140,6 +143,11 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
statusRelay: worker.NewAtomicStatus(),
statusICE: worker.NewAtomicStatus(),
dumpState: newStateDump(config.Key, connLog, services.StatusRecorder),
endpointUpdater: &endpointUpdater{
log: connLog,
wgConfig: config.WgConfig,
initiator: isWireGuardInitiator(config),
},
}
return conn, nil
@@ -173,7 +181,7 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay)
conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() {
if os.Getenv("NB_FORCE_RELAY") != "true" {
conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
}
@@ -249,7 +257,7 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgProxyICE = nil
}
if err := conn.removeWgPeer(); err != nil {
if err := conn.endpointUpdater.removeWgPeer(); err != nil {
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
}
@@ -375,12 +383,18 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
wgProxy.Work()
}
if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String())
if err = conn.endpointUpdater.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return
}
wgConfigWorkaround()
if conn.wgProxyRelay != nil {
conn.Log.Debugf("redirect packages from relayed conn to WireGuard")
conn.wgProxyRelay.RedirectAs(ep)
}
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo)
@@ -409,7 +423,7 @@ func (conn *Conn) onICEStateDisconnected() {
conn.dumpState.SwitchToRelay()
conn.wgProxyRelay.Work()
if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
if err := conn.endpointUpdater.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil {
conn.Log.Errorf("failed to switch to relay conn: %v", err)
}
@@ -418,6 +432,7 @@ func (conn *Conn) onICEStateDisconnected() {
defer conn.wgWatcherWg.Done()
conn.workerRelay.EnableWgWatcher(conn.ctx)
}()
conn.wgProxyRelay.Work()
conn.currentConnPriority = conntype.Relay
} else {
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
@@ -477,7 +492,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
}
wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := conn.endpointUpdater.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.Log.Warnf("Failed to close relay connection: %v", err)
}
@@ -545,17 +560,6 @@ func (conn *Conn) onGuardEvent() {
}
}
func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
presharedKey := conn.presharedKey(remoteRPKey)
return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey,
conn.config.WgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
presharedKey,
)
}
func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
@@ -698,10 +702,6 @@ func (conn *Conn) isICEActive() bool {
return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected
}
func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}
func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.Log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
@@ -782,6 +782,10 @@ func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}
func isWireGuardInitiator(config ConnConfig) bool {
return isController(config)
}
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil
}

View File

@@ -0,0 +1,88 @@
package peer
import (
"context"
"net"
"sync"
"time"
"github.com/sirupsen/logrus"
)
// fallbackDelay could be const but because of testing it is a var
var fallbackDelay = 5 * time.Second
type endpointUpdater struct {
log *logrus.Entry
wgConfig WgConfig
initiator bool
cancelFunc func()
configUpdateMutex sync.Mutex
}
// configureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *endpointUpdater) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error {
if e.initiator {
return e.updateWireGuardPeer(addr, remoteRPKey)
}
// prevent to run new update while cancel the previous update
e.configUpdateMutex.Lock()
if e.cancelFunc != nil {
e.cancelFunc()
}
e.configUpdateMutex.Unlock()
var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
go e.scheduleDelayedUpdate(ctx, addr, remoteRPKey)
return e.updateWireGuardPeer(nil, remoteRPKey)
}
func (e *endpointUpdater) removeWgPeer() error {
e.configUpdateMutex.Lock()
defer e.configUpdateMutex.Unlock()
if e.cancelFunc != nil {
e.cancelFunc()
}
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
func (e *endpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, remoteRPKey []byte) {
t := time.NewTimer(fallbackDelay)
defer t.Stop()
select {
case <-ctx.Done():
return
case <-t.C:
e.configUpdateMutex.Lock()
defer e.configUpdateMutex.Unlock()
if ctx.Err() != nil {
return
}
if err := e.updateWireGuardPeer(addr, remoteRPKey); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
}
}
func (e *endpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, remoteRPKey []byte) error {
// todo add, "presharedKey := e.presharedKey(remote)"
return e.wgConfig.WgInterface.UpdatePeer(
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
endpoint,
e.wgConfig.PreSharedKey,
)
}

View File

@@ -1,14 +0,0 @@
package peer
import (
"os"
"strings"
)
const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func isForceRelayed() bool {
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}

View File

@@ -21,9 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
)
const eventQueueSize = 10
@@ -201,8 +201,6 @@ type Status struct {
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
lazyConnectionEnabled bool
lastDisconnectLog map[string]time.Time
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
@@ -231,7 +229,6 @@ func NewRecorder(mgmAddress string) *Status {
notifier: newNotifier(),
mgmAddress: mgmAddress,
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
lastDisconnectLog: make(map[string]time.Time),
}
}
@@ -490,9 +487,6 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState
// info log about disconnect with impacted routes (throttled)
d.logPeerDisconnectIfNeeded(receivedState.PubKey, peerState)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
@@ -525,9 +519,6 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState
// info log about disconnect with impacted routes (throttled)
d.logPeerDisconnectIfNeeded(receivedState.PubKey, peerState)
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
d.notifyPeerListChanged()
}
@@ -538,49 +529,6 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
return nil
}
// logPeerDisconnectIfNeeded logs an info message when a routing peer transitions to disconnected
// with the number of impacted routes. Throttled to once per peer per 30 seconds.
func (d *Status) logPeerDisconnectIfNeeded(pubKey string, state State) {
if state.ConnStatus != StatusIdle {
return
}
now := time.Now()
last, ok := d.lastDisconnectLog[pubKey]
if ok && now.Sub(last) < 10*time.Second {
return
}
d.lastDisconnectLog[pubKey] = now
routes := state.GetRoutes()
numRoutes := len(routes)
fqdn := state.FQDN
if fqdn == "" {
fqdn = pubKey
}
// prepare a bounded list of impacted routes to avoid huge log lines
maxList := 20
list := make([]string, 0, maxList)
for r := range routes {
if len(list) >= maxList {
break
}
list = append(list, r)
}
more := ""
if numRoutes > len(list) {
more = ", more=" + fmt.Sprintf("%d", numRoutes-len(list))
}
if len(list) > 0 {
log.Warnf("[d] Routing peer disconnected: peer=%s impacted_routes=%d routes=%v%s", fqdn, numRoutes, list, more)
} else {
log.Warnf("[d] Routing peer disconnected: peer=%s impacted_routes=%d", fqdn, numRoutes)
}
}
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error {
d.mux.Lock()

View File

@@ -30,10 +30,10 @@ type WGWatcher struct {
peerKey string
stateDump *stateDump
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
enabledTime time.Time
ctx context.Context
ctxCancel context.CancelFunc
ctxLock sync.Mutex
enabled time.Time
}
func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher {
@@ -49,7 +49,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) {
w.log.Debugf("enable WireGuard watcher")
w.ctxLock.Lock()
w.enabledTime = time.Now()
w.enabled = time.Now()
if w.ctx != nil && w.ctx.Err() == nil {
w.log.Errorf("WireGuard watcher already enabled")
@@ -89,6 +89,8 @@ func (w *WGWatcher) DisableWgWatcher() {
func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) {
w.log.Infof("WireGuard watcher started")
debugTicker := time.NewTicker(time.Second)
timer := time.NewTimer(wgHandshakeOvertime)
defer timer.Stop()
defer ctxCancel()
@@ -97,16 +99,28 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel contex
for {
select {
case <-debugTicker.C:
handshake, err := w.wgState()
if err != nil {
w.log.Errorf("failed to read wg stats: %v", err)
continue
}
if !handshake.IsZero() {
w.log.Infof("first wg handshake detected at: %s, %s", handshake, time.Since(w.enabled))
debugTicker.Stop()
}
case <-timer.C:
handshake, ok := w.handshakeCheck(lastHandshake)
if !ok {
onDisconnectedFn()
return
}
if lastHandshake.IsZero() {
elapsed := handshake.Sub(w.enabledTime).Seconds()
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
}
/*
// todo: put it back if remove debug ticker
if lastHandshake.IsZero() {
w.log.Infof("first wg handshake detected at: %s", handshake)
}
*/
lastHandshake = *handshake

View File

@@ -40,7 +40,7 @@ func NewNetWithDiscover(iFaceDiscover ExternalIFaceDiscover, disallowList []stri
if netstack.IsEnabled() {
n.iFaceDiscover = pionDiscover{}
} else {
n.iFaceDiscover = newMobileIFaceDiscover(iFaceDiscover)
newMobileIFaceDiscover(iFaceDiscover)
}
return n, n.UpdateInterfaces()
}

View File

@@ -10,24 +10,25 @@ import (
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
@@ -293,20 +294,15 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)

View File

@@ -96,6 +96,14 @@ func (i *Info) SetFlags(
i.LazyConnectionEnabled = lazyConnectionEnabled
}
// StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)
@@ -191,3 +199,10 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
log.Debugf("all system information gathered successfully")
return info, nil
}
// UpdateStaticInfo asynchronously updates static system and platform information
func UpdateStaticInfo() {
go func() {
_ = updateStaticInfo()
}()
}

View File

@@ -15,11 +15,6 @@ import (
"github.com/netbirdio/netbird/version"
)
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
kernel := "android"

View File

@@ -19,10 +19,6 @@ import (
"github.com/netbirdio/netbird/version"
)
func UpdateStaticInfoAsync() {
go updateStaticInfo()
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
utsname := unix.Utsname{}
@@ -45,7 +41,7 @@ func GetInfo(ctx context.Context) *Info {
}
start := time.Now()
si := getStaticInfo()
si := updateStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
}

View File

@@ -18,11 +18,6 @@ import (
"github.com/netbirdio/netbird/version"
)
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
out := _getInfo()

View File

@@ -10,11 +10,6 @@ import (
"github.com/netbirdio/netbird/version"
)
// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update
func UpdateStaticInfoAsync() {
// do nothing
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {

View File

@@ -23,10 +23,6 @@ var (
getSystemInfo = defaultSysInfoImplementation
)
func UpdateStaticInfoAsync() {
go updateStaticInfo()
}
// GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info {
info := _getInfo()
@@ -52,7 +48,7 @@ func GetInfo(ctx context.Context) *Info {
}
start := time.Now()
si := getStaticInfo()
si := updateStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
}

View File

@@ -2,51 +2,246 @@ package system
import (
"context"
"fmt"
"os"
"runtime"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/version"
)
func UpdateStaticInfoAsync() {
go updateStaticInfo()
type Win32_OperatingSystem struct {
Caption string
}
// GetInfo retrieves and parses the system information
type Win32_ComputerSystem struct {
Manufacturer string
}
type Win32_ComputerSystemProduct struct {
Name string
}
type Win32_BIOS struct {
SerialNumber string
}
// CachedStaticInfo holds all the static system information that never changes
type CachedStaticInfo struct {
OSName string
OSVersion string
KernelVersion string
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment // Assuming this is from your StaticInfo struct
GoOS string
CPUs int
Kernel string
Platform string
}
var (
cachedStaticInfo *CachedStaticInfo
staticInfoOnce sync.Once
)
func init() {
go initStaticInfo()
}
// initStaticInfo initializes all static system information once
func initStaticInfo() {
staticInfoOnce.Do(func() {
log.Debugf("initializing static system information (one-time operation)")
start := time.Now()
// Get OS info
osName, osVersion := getOSNameAndVersion()
buildVersion := getBuildVersion()
// Get hardware info
si := updateStaticInfo()
cachedStaticInfo = &CachedStaticInfo{
OSName: osName,
OSVersion: osVersion,
KernelVersion: buildVersion,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
Kernel: "windows",
Platform: "unknown",
}
log.Debugf("static system information initialized in %s", time.Since(start))
})
}
// GetInfo retrieves system information (static info cached, dynamic info fresh)
func GetInfo(ctx context.Context) *Info {
initStaticInfo()
log.Debugf("gathering dynamic system information")
start := time.Now()
si := getStaticInfo()
if time.Since(start) > 1*time.Second {
log.Warnf("updateStaticInfo took %s", time.Since(start))
}
gio := &Info{
Kernel: "windows",
OSVersion: si.OSVersion,
Platform: "unknown",
OS: si.OSName,
GoOS: runtime.GOOS,
CPUs: runtime.NumCPU(),
KernelVersion: si.BuildVersion,
SystemSerialNumber: si.SystemSerialNumber,
SystemProductName: si.SystemProductName,
SystemManufacturer: si.SystemManufacturer,
Environment: si.Environment,
}
// Only gather dynamic information that might change
log.Debugf("gathering networkAddresses")
addrs, err := networkAddresses()
if err != nil {
log.Warnf("failed to discover network addresses: %s", err)
} else {
gio.NetworkAddresses = addrs
}
log.Debugf("gathering Hostname")
systemHostname, _ := os.Hostname()
gio.Hostname = extractDeviceName(ctx, systemHostname)
gio.NetbirdVersion = version.NetbirdVersion()
gio.UIVersion = extractUserAgent(ctx)
// Create Info struct using cached static info + fresh dynamic info
gio := &Info{
// Static information (cached)
Kernel: cachedStaticInfo.Kernel,
OSVersion: cachedStaticInfo.OSVersion,
Platform: cachedStaticInfo.Platform,
OS: cachedStaticInfo.OSName,
GoOS: cachedStaticInfo.GoOS,
CPUs: cachedStaticInfo.CPUs,
KernelVersion: cachedStaticInfo.KernelVersion,
SystemSerialNumber: cachedStaticInfo.SystemSerialNumber,
SystemProductName: cachedStaticInfo.SystemProductName,
SystemManufacturer: cachedStaticInfo.SystemManufacturer,
Environment: cachedStaticInfo.Environment,
// Dynamic information (fresh each call)
NetworkAddresses: addrs,
Hostname: extractDeviceName(ctx, systemHostname),
NetbirdVersion: version.NetbirdVersion(), // This might change with updates
UIVersion: extractUserAgent(ctx), // This might change
}
log.Debugf("dynamic system information gathered in %s", time.Since(start))
return gio
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
log.Error(err)
return "Windows", getBuildVersion()
}
if len(dst) == 0 {
return "Windows", getBuildVersion()
}
split := strings.Split(dst[0].Caption, " ")
if len(split) <= 3 {
return "Windows", getBuildVersion()
}
name := split[1]
version := split[2]
if split[2] == "Server" {
name = fmt.Sprintf("%s %s", split[1], split[2])
version = split[3]
}
return name, version
}
func getBuildVersion() string {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
if err != nil {
log.Error(err)
return "0.0.0.0"
}
defer func() {
deferErr := k.Close()
if deferErr != nil {
log.Error(deferErr)
}
}()
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
if err != nil {
log.Error(err)
}
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
if err != nil {
log.Error(err)
}
build, _, err := k.GetStringValue("CurrentBuildNumber")
if err != nil {
log.Error(err)
}
// Update Build Revision
ubr, _, err := k.GetIntegerValue("UBR")
if err != nil {
log.Error(err)
}
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
return ver
}
func sysNumber() (string, error) {
var dst []Win32_BIOS
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].SerialNumber, nil
}
func sysProductName() (string, error) {
var dst []Win32_ComputerSystemProduct
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
// `ComputerSystemProduct` could be empty on some virtualized systems
if len(dst) < 1 {
return "unknown", nil
}
return dst[0].Name, nil
}
func sysManufacturer() (string, error) {
var dst []Win32_ComputerSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].Manufacturer, nil
}

View File

@@ -3,7 +3,12 @@
package system
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
)
var (
@@ -11,26 +16,25 @@ var (
once sync.Once
)
// StaticInfo is an object that contains machine information that does not change
type StaticInfo struct {
SystemSerialNumber string
SystemProductName string
SystemManufacturer string
Environment Environment
// Windows specific fields
OSName string
OSVersion string
BuildVersion string
}
func updateStaticInfo() {
func updateStaticInfo() StaticInfo {
once.Do(func() {
staticInfo = newStaticInfo()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo()
wg.Done()
}()
go func() {
staticInfo.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
go func() {
staticInfo.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Wait()
})
}
func getStaticInfo() StaticInfo {
updateStaticInfo()
return staticInfo
}

View File

@@ -0,0 +1,8 @@
//go:build android || freebsd || ios
package system
// updateStaticInfo returns an empty implementation for unsupported platforms
func updateStaticInfo() StaticInfo {
return StaticInfo{}
}

View File

@@ -1,35 +0,0 @@
//go:build (linux && !android) || (darwin && !ios)
package system
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
)
func newStaticInfo() StaticInfo {
si := StaticInfo{}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo()
wg.Done()
}()
go func() {
si.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
go func() {
si.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Wait()
return si
}

View File

@@ -1,184 +0,0 @@
package system
import (
"context"
"fmt"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
"golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/system/detect_cloud"
"github.com/netbirdio/netbird/client/system/detect_platform"
)
type Win32_OperatingSystem struct {
Caption string
}
type Win32_ComputerSystem struct {
Manufacturer string
}
type Win32_ComputerSystemProduct struct {
Name string
}
type Win32_BIOS struct {
SerialNumber string
}
func newStaticInfo() StaticInfo {
si := StaticInfo{}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
si.SystemSerialNumber, si.SystemProductName, si.SystemManufacturer = sysInfo()
wg.Done()
}()
wg.Add(1)
go func() {
si.Environment.Cloud = detect_cloud.Detect(ctx)
wg.Done()
}()
wg.Add(1)
go func() {
si.Environment.Platform = detect_platform.Detect(ctx)
wg.Done()
}()
wg.Add(1)
go func() {
si.OSName, si.OSVersion = getOSNameAndVersion()
wg.Done()
}()
wg.Add(1)
go func() {
si.BuildVersion = getBuildVersion()
wg.Done()
}()
wg.Wait()
return si
}
func sysInfo() (serialNumber string, productName string, manufacturer string) {
var err error
serialNumber, err = sysNumber()
if err != nil {
log.Warnf("failed to get system serial number: %s", err)
}
productName, err = sysProductName()
if err != nil {
log.Warnf("failed to get system product name: %s", err)
}
manufacturer, err = sysManufacturer()
if err != nil {
log.Warnf("failed to get system manufacturer: %s", err)
}
return serialNumber, productName, manufacturer
}
func sysNumber() (string, error) {
var dst []Win32_BIOS
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].SerialNumber, nil
}
func sysProductName() (string, error) {
var dst []Win32_ComputerSystemProduct
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
// `ComputerSystemProduct` could be empty on some virtualized systems
if len(dst) < 1 {
return "unknown", nil
}
return dst[0].Name, nil
}
func sysManufacturer() (string, error) {
var dst []Win32_ComputerSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
return "", err
}
return dst[0].Manufacturer, nil
}
func getOSNameAndVersion() (string, string) {
var dst []Win32_OperatingSystem
query := wmi.CreateQuery(&dst, "")
err := wmi.Query(query, &dst)
if err != nil {
log.Error(err)
return "Windows", getBuildVersion()
}
if len(dst) == 0 {
return "Windows", getBuildVersion()
}
split := strings.Split(dst[0].Caption, " ")
if len(split) <= 3 {
return "Windows", getBuildVersion()
}
name := split[1]
version := split[2]
if split[2] == "Server" {
name = fmt.Sprintf("%s %s", split[1], split[2])
version = split[3]
}
return name, version
}
func getBuildVersion() string {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
if err != nil {
log.Error(err)
return "0.0.0.0"
}
defer func() {
deferErr := k.Close()
if deferErr != nil {
log.Error(deferErr)
}
}()
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
if err != nil {
log.Error(err)
}
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
if err != nil {
log.Error(err)
}
build, _, err := k.GetStringValue("CurrentBuildNumber")
if err != nil {
log.Error(err)
}
// Update Build Revision
ubr, _, err := k.GetIntegerValue("UBR")
if err != nil {
log.Error(err)
}
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
return ver
}

2
go.mod
View File

@@ -62,7 +62,7 @@ require (
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw=
github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190 h1:/ZbExdcDwRq6XgTpTf5I1DPqnC3eInEf0fcmkqR8eSg=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250820151658-9ee1b34f4190/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=

View File

@@ -20,11 +20,7 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(
context.Background(),
s.PeersManager(),
s.SettingsManager(),
s.EventStore())
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())
if err != nil {
log.Errorf("failed to create integrated peer validator: %v", err)
}

View File

@@ -1714,9 +1714,7 @@ func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, account
log.WithContext(ctx).Errorf("failed to get invalidated peer %s for account %s: %v", peerID, accountID, err)
continue
}
if peer.UserID != "" {
peers = append(peers, peer)
}
peers = append(peers, peer)
}
if len(peers) > 0 {
err := am.expireAndUpdatePeers(ctx, accountID, peers)

View File

@@ -18,7 +18,6 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
}
type managerImpl struct {
@@ -62,7 +61,3 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
}
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
}

View File

@@ -79,18 +79,3 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
}
// GetPeersByGroupIDs mocks base method.
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs)
ret0, _ := ret[0].([]*peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs.
func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
}

View File

@@ -2847,22 +2847,3 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i
}
return nil
}
func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) {
if len(groupIDs) == 0 {
return []*nbpeer.Peer{}, nil
}
var peers []*nbpeer.Peer
peerIDsSubquery := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT peer_id").
Where("account_id = ? AND group_id IN ?", accountID, groupIDs)
result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers by group IDs")
}
return peers, nil
}

View File

@@ -3607,113 +3607,3 @@ func intToIPv4(n uint32) net.IP {
binary.BigEndian.PutUint32(ip, n)
return ip
}
func TestSqlStore_GetPeersByGroupIDs(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
group1ID := "test-group-1"
group2ID := "test-group-2"
emptyGroupID := "empty-group"
peer1 := "cfefqs706sqkneg59g4g"
peer2 := "cfeg6sf06sqkneg59g50"
tests := []struct {
name string
groupIDs []string
expectedPeers []string
expectedCount int
}{
{
name: "retrieve peers from single group with multiple peers",
groupIDs: []string{group1ID},
expectedPeers: []string{peer1, peer2},
expectedCount: 2,
},
{
name: "retrieve peers from single group with one peer",
groupIDs: []string{group2ID},
expectedPeers: []string{peer1},
expectedCount: 1,
},
{
name: "retrieve peers from multiple groups (with overlap)",
groupIDs: []string{group1ID, group2ID},
expectedPeers: []string{peer1, peer2}, // should deduplicate
expectedCount: 2,
},
{
name: "retrieve peers from existing 'All' group",
groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data
expectedPeers: []string{peer1, peer2},
expectedCount: 2,
},
{
name: "retrieve peers from empty group",
groupIDs: []string{emptyGroupID},
expectedPeers: []string{},
expectedCount: 0,
},
{
name: "retrieve peers from non-existing group",
groupIDs: []string{"non-existing-group"},
expectedPeers: []string{},
expectedCount: 0,
},
{
name: "empty group IDs list",
groupIDs: []string{},
expectedPeers: []string{},
expectedCount: 0,
},
{
name: "mix of existing and non-existing groups",
groupIDs: []string{group1ID, "non-existing-group"},
expectedPeers: []string{peer1, peer2},
expectedCount: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
ctx := context.Background()
groups := []*types.Group{
{
ID: group1ID,
AccountID: accountID,
},
{
ID: group2ID,
AccountID: accountID,
},
}
require.NoError(t, store.CreateGroups(ctx, accountID, groups))
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID))
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID))
require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID))
peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs)
require.NoError(t, err)
require.Len(t, peers, tt.expectedCount)
if tt.expectedCount > 0 {
actualPeerIDs := make([]string, len(peers))
for i, peer := range peers {
actualPeerIDs[i] = peer.ID
}
assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs)
// Verify all returned peers belong to the correct account
for _, peer := range peers {
assert.Equal(t, accountID, peer.AccountID)
}
}
})
}
}

View File

@@ -136,7 +136,6 @@ type Store interface {
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)

View File

@@ -942,11 +942,6 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key)
if peer.UserID == "" {
// we do not want to expire peers that are added via setup key
continue
}
if peer.Status.LoginExpired {
continue
}

View File

@@ -130,6 +130,36 @@ repo_gpgcheck=1
EOF
}
install_aur_package() {
INSTALL_PKGS="git base-devel go"
REMOVE_PKGS=""
# Check if dependencies are installed
for PKG in $INSTALL_PKGS; do
if ! pacman -Q "$PKG" > /dev/null 2>&1; then
# Install missing package(s)
${SUDO} pacman -S "$PKG" --noconfirm
# Add installed package for clean up later
REMOVE_PKGS="$REMOVE_PKGS $PKG"
fi
done
# Build package from AUR
cd /tmp && git clone https://aur.archlinux.org/netbird.git
cd netbird && makepkg -sri --noconfirm
if ! $SKIP_UI_APP; then
cd /tmp && git clone https://aur.archlinux.org/netbird-ui.git
cd netbird-ui && makepkg -sri --noconfirm
fi
if [ -n "$REMOVE_PKGS" ]; then
# Clean up the installed packages
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
fi
}
prepare_tun_module() {
# Create the necessary file structure for /dev/net/tun
if [ ! -c /dev/net/tun ]; then
@@ -246,9 +276,12 @@ install_netbird() {
if ! $SKIP_UI_APP; then
${SUDO} rpm-ostree -y install netbird-ui
fi
# ensure the service is started after install
${SUDO} netbird service install || true
${SUDO} netbird service start || true
;;
pacman)
${SUDO} pacman -Syy
install_aur_package
# in-line with the docs at https://wiki.archlinux.org/title/Netbird
${SUDO} systemctl enable --now netbird@main.service
;;
pkg)
# Check if the package is already installed
@@ -425,7 +458,11 @@ if type uname >/dev/null 2>&1; then
elif [ -x "$(command -v yum)" ]; then
PACKAGE_MANAGER="yum"
echo "The installation will be performed using yum package manager"
elif [ -x "$(command -v pacman)" ]; then
PACKAGE_MANAGER="pacman"
echo "The installation will be performed using pacman package manager"
fi
else
echo "Unable to determine OS type from /etc/os-release"
exit 1

View File

@@ -9,30 +9,34 @@ import (
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/internals/server/config"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/peers"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
)
@@ -68,31 +72,13 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.
EXPECT().
ValidateUserPermissions(
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
Return(true, nil).
AnyTimes()
peersManger := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.
EXPECT().
@@ -109,6 +95,19 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.
EXPECT().
ValidateUserPermissions(
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
Return(true, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)

View File

@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/formatter"
)
const defaultLogSize = 100
const defaultLogSize = 15
const (
LogConsole = "console"