Compare commits

...

7 Commits

Author SHA1 Message Date
Viktor Liu
063cbdc6d8 Enable trace logging in WASM client 2026-01-16 15:06:58 +08:00
Maycon Santos
291e640b28 [client] Change priority between local and dns route handlers (#5106)
* Change priority between local and dns route handlers

* update priority tests
2026-01-15 17:30:10 +01:00
Pascal Fischer
efb954b7d6 [management] adapt ratelimiting (#5080) 2026-01-15 16:39:14 +01:00
Vlad
cac9326d3d [management] fetch all users data from external cache in one request (#5104)
---------

Co-authored-by: pascal <pascal@netbird.io>
2026-01-14 17:09:17 +01:00
Viktor Liu
520d9c66cf [client] Fix netstack upstream dns and add wasm debug methods (#4648) 2026-01-14 13:56:16 +01:00
Misha Bragin
ff10498a8b Feature/embedded STUN (#5062) 2026-01-14 13:13:30 +01:00
Zoltan Papp
00b747ad5d Handle fallback for invalid loginuid in ui-post-install.sh. (#5099) 2026-01-14 09:53:14 +01:00
37 changed files with 1881 additions and 335 deletions

View File

@@ -314,9 +314,8 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
profName = activeProf.Name profName = activeProf.Name
} }
statusOutputString = nbstatus.ParseToFullDetailSummary( overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), statusOutputString = overview.FullDetailSummary()
)
} }
return statusOutputString return statusOutputString
} }

View File

@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder) statusOutputString = outputInformationHolder.FullDetailSummary()
case jsonFlag: case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder) statusOutputString, err = outputInformationHolder.JSON()
case yamlFlag: case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder) statusOutputString, err = outputInformationHolder.YAML()
default: default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false) statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
} }
if err != nil { if err != nil {

View File

@@ -10,6 +10,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"sync" "sync"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh" sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
var ( var (
@@ -29,6 +31,11 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized") ErrConfigNotInitialized = errors.New("config not initialized")
) )
const (
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond
)
// Client manages a netbird embedded client instance. // Client manages a netbird embedded client instance.
type Client struct { type Client struct {
deviceName string deviceName string
@@ -38,6 +45,7 @@ type Client struct {
setupKey string setupKey string
jwtToken string jwtToken string
connect *internal.ConnectClient connect *internal.ConnectClient
recorder *peer.Status
} }
// Options configures a new Client. // Options configures a new Client.
@@ -161,11 +169,17 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error { func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.cancel != nil { if c.connect != nil {
return ErrClientAlreadyStarted return ErrClientAlreadyStarted
} }
ctx := internal.CtxInitState(context.Background()) ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
defer func() {
if c.connect == nil {
cancel()
}
}()
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
@@ -173,7 +187,9 @@ func (c *Client) Start(startCtx context.Context) error {
} }
recorder := peer.NewRecorder(c.config.ManagementURL.String()) recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false) client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
@@ -197,6 +213,7 @@ func (c *Client) Start(startCtx context.Context) error {
} }
c.connect = client c.connect = client
c.cancel = cancel
return nil return nil
} }
@@ -211,17 +228,23 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted return ErrClientNotStarted
} }
if c.cancel != nil {
c.cancel()
c.cancel = nil
}
done := make(chan error, 1) done := make(chan error, 1)
connect := c.connect
go func() { go func() {
done <- c.connect.Stop() done <- connect.Stop()
}() }()
select { select {
case <-ctx.Done(): case <-ctx.Done():
c.cancel = nil c.connect = nil
return ctx.Err() return ctx.Err()
case err := <-done: case err := <-done:
c.cancel = nil c.connect = nil
if err != nil { if err != nil {
return fmt.Errorf("stop: %w", err) return fmt.Errorf("stop: %w", err)
} }
@@ -241,18 +264,40 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network. // Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
// With lazy connections, the connection is established on first traffic.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
logrus.Infof("embed.Dial called: network=%s, address=%s", network, address)
// Check context status upfront
if ctx.Err() != nil {
logrus.Warnf("embed.Dial: context already cancelled/expired: %v", ctx.Err())
return nil, ctx.Err()
}
engine, err := c.getEngine() engine, err := c.getEngine()
if err != nil { if err != nil {
logrus.Errorf("embed.Dial: getEngine failed: %v", err)
return nil, err return nil, err
} }
nsnet, err := engine.GetNet() nsnet, err := engine.GetNet()
if err != nil { if err != nil {
logrus.Errorf("embed.Dial: GetNet failed: %v", err)
return nil, fmt.Errorf("get net: %w", err) return nil, fmt.Errorf("get net: %w", err)
} }
return nsnet.DialContext(ctx, network, address) // Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when DialContext is called. The netstack
// dial triggers WireGuard traffic which activates the lazy connection.
logrus.Debugf("embed.Dial: calling nsnet.DialContext for %s", address)
conn, err := nsnet.DialContext(ctx, network, address)
if err != nil {
logrus.Errorf("embed.Dial: nsnet.DialContext failed: %v", err)
return nil, err
}
logrus.Infof("embed.Dial: successfully connected to %s", address)
return conn, nil
} }
// DialContext dials a network address in the netbird network with context // DialContext dials a network address in the netbird network with context
@@ -315,6 +360,90 @@ func (c *Client) NewHTTPClient() *http.Client {
} }
} }
// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()
if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}
if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}
return recorder.GetFullStatus(), nil
}
// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}
return syncResp, nil
}
// WaitForPeerConnection waits for a peer with the given IP to be connected.
func (c *Client) WaitForPeerConnection(ctx context.Context, peerIP string) error {
logrus.Infof("Waiting for peer %s to be connected", peerIP)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for peer %s to connect: %w", peerIP, ctx.Err())
case <-ticker.C:
status, err := c.Status()
if err != nil {
logrus.Debugf("Error getting status while waiting for peer: %v", err)
continue
}
for _, p := range status.Peers {
if p.IP == peerIP && p.ConnStatus == peer.StatusConnected {
logrus.Infof("Peer %s is now connected (relayed: %v)", peerIP, p.Relayed)
return nil
}
}
logrus.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
}
// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
c.mu.Lock()
connect := c.connect
c.mu.Unlock()
// Note: ConnectClient doesn't have SetLogLevel method
_ = connect
return nil
}
// VerifySSHHostKey verifies an SSH host key against stored peer keys. // VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network, // Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures. // ErrNoStoredKey if peer has no stored key, or an error for verification failures.

View File

@@ -23,10 +23,10 @@ func NewNSDialer(net *netstack.Net) *NSDialer {
} }
func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugf("dialing %s %s", network, addr) log.Infof("NSDialer.Dial: network=%s, addr=%s", network, addr)
conn, err := d.net.Dial(network, addr) conn, err := d.net.Dial(network, addr)
if err != nil { if err != nil {
log.Debugf("failed to deal connection: %s", err) log.Warnf("NSDialer.Dial failed: %s", err)
} }
return conn, err return conn, err
} }

View File

@@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
return syncResponse, nil return syncResponse, nil
} }
// SetLogLevel sets the log level for the firewall manager if the engine is running.
func (c *ConnectClient) SetLogLevel(level log.Level) {
engine := c.Engine()
if engine == nil {
return
}
fwManager := engine.GetFirewallManager()
if fwManager != nil {
fwManager.SetLogLevel(level)
}
}
// Status returns the current client status // Status returns the current client status
func (c *ConnectClient) Status() StatusType { func (c *ConnectClient) Status() StatusType {
if c == nil { if c == nil {

View File

@@ -16,8 +16,8 @@ import (
const ( const (
PriorityMgmtCache = 150 PriorityMgmtCache = 150
PriorityLocal = 100 PriorityDNSRoute = 100
PriorityDNSRoute = 75 PriorityLocal = 75
PriorityUpstream = 50 PriorityUpstream = 50
PriorityDefault = 1 PriorityDefault = 1
PriorityFallback = -100 PriorityFallback = -100

View File

@@ -631,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
handler, err := newUpstreamResolver( handler, err := newUpstreamResolver(
s.ctx, s.ctx,
s.wgInterface.Name(), s.wgInterface,
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder, s.hostsDNSHolder,
nbdns.RootZone, nbdns.RootZone,
@@ -743,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority) log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver( handler, err := newUpstreamResolver(
s.ctx, s.ctx,
s.wgInterface.Name(), s.wgInterface,
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder, s.hostsDNSHolder,
domainGroup.domain, domainGroup.domain,
@@ -926,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() {
handler, err := newUpstreamResolver( handler, err := newUpstreamResolver(
s.ctx, s.ctx,
s.wgInterface.Name(), s.wgInterface,
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.statusRecorder, s.statusRecorder,
s.hostsDNSHolder, s.hostsDNSHolder,
nbdns.RootZone, nbdns.RootZone,

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
return configurer.WGStats{}, nil return configurer.WGStats{}, nil
} }
func (w *mocWGIface) GetNet() *netstack.Net {
return nil
}
var zoneRecords = []nbdns.SimpleRecord{ var zoneRecords = []nbdns.SimpleRecord{
{ {
Name: "peera.netbird.cloud", Name: "peera.netbird.cloud",
@@ -180,7 +185,7 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}, expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
}, },
{ {
name: "New Config Should Succeed", name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -221,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
}, },
{ {
name: "Smaller Config Serial Should Be Skipped", name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{}, initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 2, initSerial: 2,
inputSerial: 1, inputSerial: 1,
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Empty NS Group Domain Or Not Primary Element Should Fail", name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{}, initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
ServiceEnable: true, ServiceEnable: true,
CustomZones: []nbdns.CustomZone{ CustomZones: []nbdns.CustomZone{
@@ -251,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Invalid NS Group Nameservers list Should Fail", name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{}, initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
ServiceEnable: true, ServiceEnable: true,
CustomZones: []nbdns.CustomZone{ CustomZones: []nbdns.CustomZone{
@@ -273,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true, shouldFail: true,
}, },
{ {
name: "Invalid Custom Zone Records list Should Skip", name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{}, initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: make(registeredHandlerMap),
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
ServiceEnable: true, ServiceEnable: true,
CustomZones: []nbdns.CustomZone{ CustomZones: []nbdns.CustomZone{
@@ -299,7 +304,7 @@ func TestUpdateDNSServer(t *testing.T) {
}}, }},
}, },
{ {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -315,7 +320,7 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{}, expectedLocalQs: []dns.Question{},
}, },
{ {
name: "Disabled Service Should clean map", name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -2047,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {
func TestLocalResolverPriorityConstants(t *testing.T) { func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly // Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route") assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream") assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default") assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")

View File

@@ -18,6 +18,7 @@ import (
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/resutil"
@@ -418,6 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil return rm, t, nil
} }
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
return nil, err
}
// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}
return reply, nil
}
func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
conn, err := nsNet.DialContext(ctx, network, upstream)
if err != nil {
return nil, fmt.Errorf("with %s: %w", network, err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close DNS connection: %v", err)
}
}()
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
}
dnsConn := &dns.Conn{Conn: conn}
if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)
}
reply, err := dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("read %s message: %w", network, err)
}
return reply, nil
}
// FormatPeerStatus formats peer connection status information for debugging DNS timeouts // FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string { func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected isConnected := peerState.ConnStatus == peer.StatusConnected

View File

@@ -23,9 +23,7 @@ type upstreamResolver struct {
// first time, and we need to wait for a while to start to use again the proper DNS resolver. // first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ WGIface,
_ netip.Addr,
_ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string, domain string,

View File

@@ -5,22 +5,23 @@ package dns
import ( import (
"context" "context"
"net/netip" "net/netip"
"runtime"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
) )
type upstreamResolver struct { type upstreamResolver struct {
*upstreamResolverBase *upstreamResolverBase
nsNet *netstack.Net
} }
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, wgIface WGIface,
_ netip.Addr,
_ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,
@@ -28,12 +29,23 @@ func newUpstreamResolver(
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{ nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),
} }
upstreamResolverBase.upstreamClient = nonIOS upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil return nonIOS, nil
} }
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
// Similar to iOS logic, we should determine if the DNS server is reachable directly
// or needs to go through the tunnel, and only use netstack when necessary.
// For now, only use netstack on JS platform where direct access is not possible.
if u.nsNet != nil && runtime.GOOS == "js" {
start := time.Now()
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
return reply, time.Since(start), err
}
client := &dns.Client{ client := &dns.Client{
Timeout: ClientTimeout, Timeout: ClientTimeout,
} }

View File

@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
interfaceName string, wgIface WGIface,
ip netip.Addr,
net netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,
@@ -37,9 +35,9 @@ func newUpstreamResolver(
ios := &upstreamResolverIOS{ ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase, upstreamResolverBase: upstreamResolverBase,
lIP: ip, lIP: wgIface.Address().IP,
lNet: net, lNet: wgIface.Address().Network,
interfaceName: interfaceName, interfaceName: wgIface.Name(),
} }
ios.upstreamClient = ios ios.upstreamClient = ios

View File

@@ -2,13 +2,17 @@ package dns
import ( import (
"context" "context"
"net"
"net/netip" "net/netip"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
) )
@@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
// Convert test servers to netip.AddrPort // Convert test servers to netip.AddrPort
var servers []netip.AddrPort var servers []netip.AddrPort
for _, server := range testCase.InputServers { for _, server := range testCase.InputServers {
@@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
} }
} }
type mockNetstackProvider struct{}
func (m *mockNetstackProvider) Name() string { return "mock" }
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}
type mockUpstreamResolver struct { type mockUpstreamResolver struct {
r *dns.Msg r *dns.Msg
rtt time.Duration rtt time.Duration

View File

@@ -5,6 +5,8 @@ package dns
import ( import (
"net" "net"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -17,4 +19,5 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
} }

View File

@@ -1,6 +1,8 @@
package dns package dns
import ( import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -12,5 +14,6 @@ type WGIface interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
GetInterfaceGUIDString() (string, error) GetInterfaceGUIDString() (string, error)
} }

View File

@@ -1748,22 +1748,26 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
} }
e.syncMsgMux.Unlock() e.syncMsgMux.Unlock()
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)
// Skip STUN/TURN probing for JS/WASM as it's not available
relayHealthy := true relayHealthy := true
for _, res := range results { if runtime.GOOS != "js" {
if res.Err != nil { var results []relay.ProbeResult
relayHealthy = false if waitForResult {
break results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
} }
e.statusRecorder.UpdateRelayStates(results)
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
}
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
} }
log.Debugf("relay health check: healthy=%t", relayHealthy)
allHealthy := signalHealthy && managementHealthy && relayHealthy allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy) log.Debugf("all health checks completed: healthy=%t", allHealthy)

View File

@@ -14,6 +14,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -158,6 +159,7 @@ type FullStatus struct {
NSGroupStates []NSGroupState NSGroupStates []NSGroupState
NumOfForwardingRules int NumOfForwardingRules int
LazyConnectionEnabled bool LazyConnectionEnabled bool
Events []*proto.SystemEvent
} }
type StatusChangeSubscription struct { type StatusChangeSubscription struct {
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
} }
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...) fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
fullStatus.Events = d.GetEventHistory()
return fullStatus return fullStatus
} }
@@ -1181,3 +1184,97 @@ type EventSubscription struct {
func (s *EventSubscription) Events() <-chan *proto.SystemEvent { func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
return s.events return s.events
} }
// ToProto converts FullStatus to proto.FullStatus.
func (fs FullStatus) ToProto() *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fs.ManagementState.URL
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
if err := fs.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fs.SignalState.URL
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
if err := fs.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)
for _, peerState := range fs.Peers {
networks := maps.Keys(peerState.GetRoutes())
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: networks,
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fs.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fs.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
pbFullStatus.Events = fs.Events
return &pbFullStatus
}

View File

@@ -17,13 +17,13 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
@@ -38,11 +38,6 @@ type internalDNATer interface {
AddInternalDNATMapping(netip.Addr, netip.Addr) error AddInternalDNATMapping(netip.Addr, netip.Addr) error
} }
type wgInterface interface {
Name() string
Address() wgaddr.Address
}
type DnsInterceptor struct { type DnsInterceptor struct {
mu sync.RWMutex mu sync.RWMutex
route *route.Route route *route.Route
@@ -52,7 +47,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server dnsServer nbdns.Server
currentPeerKey string currentPeerKey string
interceptedDomains domainMap interceptedDomains domainMap
wgInterface wgInterface wgInterface iface.WGIface
peerStore *peerstore.Store peerStore *peerstore.Store
firewall firewall.Manager firewall firewall.Manager
fakeIPManager *fakeip.Manager fakeIPManager *fakeip.Manager
@@ -250,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
}
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
@@ -264,20 +253,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel() defer cancel()
startTime := time.Now() reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream) if reply == nil {
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return return
} }
@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
return return
} }
// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
// Returns the DNS reply on success, or nil on error (error responses are written internally).
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
startTime := time.Now()
nsNet := d.wgInterface.GetNet()
var reply *dns.Msg
var err error
if nsNet != nil {
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
} else {
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if clientErr != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
return nil
}
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
}
if err == nil {
return reply
}
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return nil
}
func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string { func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil { if d.statusRecorder == nil {
return "" return ""

View File

@@ -4,6 +4,8 @@ import (
"net" "net"
"net/netip" "net/netip"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
} }

View File

@@ -173,20 +173,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (
log.SetLevel(level) log.SetLevel(level)
if s.connectClient == nil { if s.connectClient != nil {
return nil, fmt.Errorf("connect client not initialized") s.connectClient.SetLogLevel(level)
} }
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
}
fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
}
fwManager.SetLogLevel(level)
log.Infof("Log level set to %s", level.String()) log.Infof("Log level set to %s", level.String())

View File

@@ -1,8 +1,6 @@
package server package server
import ( import (
"context"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
} }
} }
} }
func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
events := s.statusRecorder.GetEventHistory()
return &proto.GetEventsResponse{Events: events}, nil
}

View File

@@ -13,15 +13,12 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -1067,11 +1064,9 @@ func (s *Server) Status(
if msg.GetFullPeerStatus { if msg.GetFullPeerStatus {
s.runProbes(msg.ShouldRunProbes) s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus() fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory() pbFullStatus.Events = s.statusRecorder.GetEventHistory()
pbFullStatus.SshServerState = s.getSSHServerState() pbFullStatus.SshServerState = s.getSSHServerState()
statusResponse.FullStatus = pbFullStatus statusResponse.FullStatus = pbFullStatus
} }
@@ -1600,94 +1595,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
return defaultDuration return defaultDuration
} }
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
// sendTerminalNotification sends a terminal notification message // sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired. // to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error { func sendTerminalNotification() error {

View File

@@ -325,61 +325,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
} }
} }
func ParseToJSON(overview OutputOverview) (string, error) { // JSON returns the status overview as a JSON string.
jsonBytes, err := json.Marshal(overview) func (o *OutputOverview) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil { if err != nil {
return "", fmt.Errorf("json marshal failed") return "", fmt.Errorf("json marshal failed")
} }
return string(jsonBytes), err return string(jsonBytes), err
} }
func ParseToYAML(overview OutputOverview) (string, error) { // YAML returns the status overview as a YAML string.
yamlBytes, err := yaml.Marshal(overview) func (o *OutputOverview) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil { if err != nil {
return "", fmt.Errorf("yaml marshal failed") return "", fmt.Errorf("yaml marshal failed")
} }
return string(yamlBytes), nil return string(yamlBytes), nil
} }
func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string { // GeneralSummary returns a general summary of the status overview.
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
var managementConnString string var managementConnString string
if overview.ManagementState.Connected { if o.ManagementState.Connected {
managementConnString = "Connected" managementConnString = "Connected"
if showURL { if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL) managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
} }
} else { } else {
managementConnString = "Disconnected" managementConnString = "Disconnected"
if overview.ManagementState.Error != "" { if o.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error) managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
} }
} }
var signalConnString string var signalConnString string
if overview.SignalState.Connected { if o.SignalState.Connected {
signalConnString = "Connected" signalConnString = "Connected"
if showURL { if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL) signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
} }
} else { } else {
signalConnString = "Disconnected" signalConnString = "Disconnected"
if overview.SignalState.Error != "" { if o.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error) signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
} }
} }
interfaceTypeString := "Userspace" interfaceTypeString := "Userspace"
interfaceIP := overview.IP interfaceIP := o.IP
if overview.KernelInterface { if o.KernelInterface {
interfaceTypeString = "Kernel" interfaceTypeString = "Kernel"
} else if overview.IP == "" { } else if o.IP == "" {
interfaceTypeString = "N/A" interfaceTypeString = "N/A"
interfaceIP = "N/A" interfaceIP = "N/A"
} }
var relaysString string var relaysString string
if showRelays { if showRelays {
for _, relay := range overview.Relays.Details { for _, relay := range o.Relays.Details {
available := "Available" available := "Available"
reason := "" reason := ""
@@ -395,18 +398,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
} }
} else { } else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
} }
networks := "-" networks := "-"
if len(overview.Networks) > 0 { if len(o.Networks) > 0 {
sort.Strings(overview.Networks) sort.Strings(o.Networks)
networks = strings.Join(overview.Networks, ", ") networks = strings.Join(o.Networks, ", ")
} }
var dnsServersString string var dnsServersString string
if showNameServers { if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups { for _, nsServerGroup := range o.NSServerGroups {
enabled := "Available" enabled := "Available"
if !nsServerGroup.Enabled { if !nsServerGroup.Enabled {
enabled = "Unavailable" enabled = "Unavailable"
@@ -430,25 +433,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
) )
} }
} else { } else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups)) dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
} }
rosenpassEnabledStatus := "false" rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled { if o.RosenpassEnabled {
rosenpassEnabledStatus = "true" rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive { if o.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
} }
} }
lazyConnectionEnabledStatus := "false" lazyConnectionEnabledStatus := "false"
if overview.LazyConnectionEnabled { if o.LazyConnectionEnabled {
lazyConnectionEnabledStatus = "true" lazyConnectionEnabledStatus = "true"
} }
sshServerStatus := "Disabled" sshServerStatus := "Disabled"
if overview.SSHServerState.Enabled { if o.SSHServerState.Enabled {
sessionCount := len(overview.SSHServerState.Sessions) sessionCount := len(o.SSHServerState.Sessions)
if sessionCount > 0 { if sessionCount > 0 {
sessionWord := "session" sessionWord := "session"
if sessionCount > 1 { if sessionCount > 1 {
@@ -460,7 +463,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
} }
if showSSHSessions && sessionCount > 0 { if showSSHSessions && sessionCount > 0 {
for _, session := range overview.SSHServerState.Sessions { for _, session := range o.SSHServerState.Sessions {
var sessionDisplay string var sessionDisplay string
if session.JWTUsername != "" { if session.JWTUsername != "" {
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s", sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
@@ -484,7 +487,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
} }
} }
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)
goos := runtime.GOOS goos := runtime.GOOS
goarch := runtime.GOARCH goarch := runtime.GOARCH
@@ -512,30 +515,31 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"Forwarding rules: %d\n"+ "Forwarding rules: %d\n"+
"Peers count: %s\n", "Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm), fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion, o.DaemonVersion,
version.NetbirdVersion(), version.NetbirdVersion(),
overview.ProfileName, o.ProfileName,
managementConnString, managementConnString,
signalConnString, signalConnString,
relaysString, relaysString,
dnsServersString, dnsServersString,
domain.Domain(overview.FQDN).SafeString(), domain.Domain(o.FQDN).SafeString(),
interfaceIP, interfaceIP,
interfaceTypeString, interfaceTypeString,
rosenpassEnabledStatus, rosenpassEnabledStatus,
lazyConnectionEnabledStatus, lazyConnectionEnabledStatus,
sshServerStatus, sshServerStatus,
networks, networks,
overview.NumberOfForwardingRules, o.NumberOfForwardingRules,
peersCountString, peersCountString,
) )
return summary return summary
} }
func ParseToFullDetailSummary(overview OutputOverview) string { // FullDetailSummary returns a full detailed summary with peer details and events.
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) func (o *OutputOverview) FullDetailSummary() string {
parsedEventsString := parseEvents(overview.Events) parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
summary := ParseGeneralSummary(overview, true, true, true, true) parsedEventsString := parseEvents(o.Events)
summary := o.GeneralSummary(true, true, true, true)
return fmt.Sprintf( return fmt.Sprintf(
"Peers detail:"+ "Peers detail:"+

View File

@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
} }
func TestParsingToJSON(t *testing.T) { func TestParsingToJSON(t *testing.T) {
jsonString, _ := ParseToJSON(overview) jsonString, _ := overview.JSON()
//@formatter:off //@formatter:off
expectedJSONString := ` expectedJSONString := `
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
} }
func TestParsingToYAML(t *testing.T) { func TestParsingToYAML(t *testing.T) {
yaml, _ := ParseToYAML(overview) yaml, _ := overview.YAML()
expectedYAML := expectedYAML :=
`peers: `peers:
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate) lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake) lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := ParseToFullDetailSummary(overview) detail := overview.FullDetailSummary()
expectedDetail := fmt.Sprintf( expectedDetail := fmt.Sprintf(
`Peers detail: `Peers detail:
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
} }
func TestParsingToShortVersion(t *testing.T) { func TestParsingToShortVersion(t *testing.T) {
shortVersion := ParseGeneralSummary(overview, false, false, false, false) shortVersion := overview.GeneralSummary(false, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1 Daemon version: 0.14.1

View File

@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string var postUpStatusOutput string
if postUpStatus != nil { if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) postUpStatusOutput = overview.FullDetailSummary()
} }
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput) statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string var preDownStatusOutput string
if preDownStatus != nil { if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) preDownStatusOutput = overview.FullDetailSummary()
} }
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration) time.Now().Format(time.RFC3339), params.duration)
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string var statusOutput string
if statusResp != nil { if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview) statusOutput = overview.FullDetailSummary()
} }
request := &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{

View File

@@ -9,20 +9,31 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"
netbird "github.com/netbirdio/netbird/client/embed" netbird "github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/client/proto"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection" sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh" "github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
) )
const ( const (
clientStartTimeout = 30 * time.Second clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second clientStopTimeout = 10 * time.Second
defaultLogLevel = "warn" pingTimeout = 10 * time.Second
defaultSSHDetectionTimeout = 20 * time.Second defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond
icmpEchoRequest = 8
icmpCodeEcho = 0
pingBufferSize = 1500
) )
func main() { func main() {
@@ -113,18 +124,45 @@ func createStopMethod(client *netbird.Client) js.Func {
}) })
} }
// validateSSHArgs validates SSH connection arguments
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
if len(args) < 2 {
return "", 0, "", js.ValueOf("error: requires host and port")
}
if args[0].Type() != js.TypeString {
return "", 0, "", js.ValueOf("host parameter must be a string")
}
if args[1].Type() != js.TypeNumber {
return "", 0, "", js.ValueOf("port parameter must be a number")
}
host = args[0].String()
port = args[1].Int()
username = "root"
if len(args) > 2 {
if args[2].Type() == js.TypeString && args[2].String() != "" {
username = args[2].String()
} else if args[2].Type() != js.TypeString {
return "", 0, "", js.ValueOf("username parameter must be a string")
}
}
return host, port, username, js.Undefined()
}
// createSSHMethod creates the SSH connection method // createSSHMethod creates the SSH connection method
func createSSHMethod(client *netbird.Client) js.Func { func createSSHMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any { return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 { host, port, username, validationErr := validateSSHArgs(args)
return js.ValueOf("error: requires host and port") if !validationErr.IsUndefined() {
} if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
return validationErr
host := args[0].String() }
port := args[1].Int() return createPromise(func(resolve, reject js.Value) {
username := "root" reject.Invoke(validationErr)
if len(args) > 2 && args[2].String() != "" { })
username = args[2].String()
} }
var jwtToken string var jwtToken string
@@ -133,6 +171,9 @@ func createSSHMethod(client *netbird.Client) js.Func {
} }
return createPromise(func(resolve, reject js.Value) { return createPromise(func(resolve, reject js.Value) {
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when Dial is called in ssh.Connect().
sshClient := ssh.NewClient(client) sshClient := ssh.NewClient(client)
if err := sshClient.Connect(host, port, username, jwtToken); err != nil { if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
@@ -154,6 +195,110 @@ func createSSHMethod(client *netbird.Client) js.Func {
}) })
} }
func performPing(client *netbird.Client, hostname string) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
start := time.Now()
conn, err := client.Dial(ctx, "ping", hostname)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
return
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close ping connection: %v", err)
}
}()
icmpData := make([]byte, 8)
icmpData[0] = icmpEchoRequest
icmpData[1] = icmpCodeEcho
if _, err := conn.Write(icmpData); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
return
}
buf := make([]byte, pingBufferSize)
if _, err := conn.Read(buf); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
return
}
latency := time.Since(start)
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
}
func performPingTCP(client *netbird.Client, hostname string, port int) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()
address := fmt.Sprintf("%s:%d", hostname, port)
start := time.Now()
conn, err := client.Dial(ctx, "tcp", address)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
return
}
latency := time.Since(start)
if err := conn.Close(); err != nil {
log.Debugf("failed to close TCP connection: %v", err)
}
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
}
// createPingMethod creates the ping method
func createPingMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: hostname required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
hostname := args[0].String()
return createPromise(func(resolve, reject js.Value) {
performPing(client, hostname)
resolve.Invoke(js.Undefined())
})
})
}
// createPingTCPMethod creates the pingtcp method
func createPingTCPMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeNumber {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a number"))
})
}
hostname := args[0].String()
port := args[1].Int()
return createPromise(func(resolve, reject js.Value) {
performPingTCP(client, hostname, port)
resolve.Invoke(js.Undefined())
})
})
}
// createProxyRequestMethod creates the proxyRequest method // createProxyRequestMethod creates the proxyRequest method
func createProxyRequestMethod(client *netbird.Client) js.Func { func createProxyRequestMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any { return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -162,6 +307,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
} }
request := args[0] request := args[0]
if request.Type() != js.TypeObject {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("request parameter must be an object"))
})
}
return createPromise(func(resolve, reject js.Value) { return createPromise(func(resolve, reject js.Value) {
response, err := http.ProxyRequest(client, request) response, err := http.ProxyRequest(client, request)
@@ -181,23 +331,255 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
return js.ValueOf("error: hostname and port required") return js.ValueOf("error: hostname and port required")
} }
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a string"))
})
}
proxy := rdp.NewRDCleanPathProxy(client) proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String()) return proxy.CreateProxy(args[0].String(), args[1].String())
}) })
} }
// getStatusOverview is a helper to get the status overview
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
fullStatus, err := client.Status()
if err != nil {
return nbstatus.OutputOverview{}, err
}
pbFullStatus := fullStatus.ToProto()
statusResp := &proto.StatusResponse{
DaemonVersion: version.NetbirdVersion(),
FullStatus: pbFullStatus,
}
return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
}
// createStatusMethod creates the status method that returns JSON
func createStatusMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonStr, err := overview.JSON()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
resolve.Invoke(jsonObj)
})
})
}
// createStatusSummaryMethod creates the statusSummary method
func createStatusSummaryMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
summary := overview.GeneralSummary(false, false, false, false)
js.Global().Get("console").Call("log", summary)
resolve.Invoke(js.Undefined())
})
})
}
// createStatusDetailMethod creates the statusDetail method
func createStatusDetailMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Info("statusDetail called")
return createPromise(func(resolve, reject js.Value) {
log.Info("statusDetail: getting overview")
overview, err := getStatusOverview(client)
if err != nil {
log.Errorf("statusDetail: getStatusOverview failed: %v", err)
reject.Invoke(js.ValueOf(err.Error()))
return
}
log.Info("statusDetail: generating full detail summary")
detail := overview.FullDetailSummary()
log.Infof("statusDetail: detail length=%d", len(detail))
js.Global().Get("console").Call("log", detail)
resolve.Invoke(js.Undefined())
})
})
}
// createWaitForPeerConnectionMethod creates a method that waits for a peer to be connected
func createWaitForPeerConnectionMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
if len(args) < 1 {
reject.Invoke(js.ValueOf("peer IP address required"))
return
}
peerIP := args[0].String()
timeoutMs := int(defaultPeerConnectionTimeout.Milliseconds())
if len(args) > 1 && !args[1].IsUndefined() && !args[1].IsNull() {
timeoutMs = args[1].Int()
}
timeout := time.Duration(timeoutMs) * time.Millisecond
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
log.Infof("Waiting for peer %s to be connected (timeout: %v)", peerIP, timeout)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
reject.Invoke(js.ValueOf(fmt.Sprintf("timeout waiting for peer %s to connect", peerIP)))
return
case <-ticker.C:
overview, err := getStatusOverview(client)
if err != nil {
log.Debugf("Error getting status: %v", err)
continue
}
for _, peer := range overview.Peers.Details {
if peer.IP == peerIP && peer.Status == "Connected" {
log.Infof("Peer %s is now connected (type: %s)", peerIP, peer.ConnType)
resolve.Invoke(js.ValueOf(map[string]interface{}{
"ip": peer.IP,
"status": peer.Status,
"connType": peer.ConnType,
}))
return
}
}
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
})
})
}
// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
syncResp, err := client.GetLatestSyncResponse()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
AllowPartial: true,
}
jsonBytes, err := options.Marshal(syncResp)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
resolve.Invoke(jsonObj)
})
})
}
// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
func createSetLogLevelMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: log level required")
}
if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("log level parameter must be a string"))
})
}
logLevel := args[0].String()
return createPromise(func(resolve, reject js.Value) {
if err := client.SetLogLevel(logLevel); err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
return
}
log.Infof("Log level set to: %s", logLevel)
resolve.Invoke(js.ValueOf(true))
})
})
}
// createPromise is a helper to create JavaScript promises // createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value { func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0] resolve := promiseArgs[0]
reject := promiseArgs[1] reject := promiseArgs[1]
go handler(resolve, reject) // Wrap reject to always log the error
loggingReject := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) > 0 {
log.Errorf("Promise rejected: %v", args[0])
js.Global().Get("console").Call("error", "WASM Promise rejected:", args[0])
}
reject.Invoke(args[0])
return nil
})
go handler(resolve, loggingReject.Value)
return nil return nil
})) }))
} }
// waitForPeerConnected waits for a peer with the given IP to be connected
func waitForPeerConnected(ctx context.Context, client *netbird.Client, peerIP string) error {
log.Infof("Waiting for peer %s to be connected before operation", peerIP)
ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for peer %s to connect", peerIP)
case <-ticker.C:
overview, err := getStatusOverview(client)
if err != nil {
log.Debugf("Error getting status while waiting for peer: %v", err)
continue
}
for _, peer := range overview.Peers.Details {
if peer.IP == peerIP && peer.Status == "Connected" {
log.Infof("Peer %s is now connected (type: %s), proceeding with operation", peerIP, peer.ConnType)
return nil
}
}
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
}
// createDetectSSHServerMethod creates the SSH server detection method // createDetectSSHServerMethod creates the SSH server detection method
func createDetectSSHServerMethod(client *netbird.Client) js.Func { func createDetectSSHServerMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any { return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -220,6 +602,10 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel() defer cancel()
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when Dial is called. Waiting would cause
// a deadlock since lazy connections only open on traffic.
serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port) serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
if err != nil { if err != nil {
reject.Invoke(err.Error()) reject.Invoke(err.Error())
@@ -237,17 +623,25 @@ func createClientObject(client *netbird.Client) js.Value {
obj["start"] = createStartMethod(client) obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client) obj["stop"] = createStopMethod(client)
obj["ping"] = createPingMethod(client)
obj["pingtcp"] = createPingTCPMethod(client)
obj["detectSSHServerType"] = createDetectSSHServerMethod(client) obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client) obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client) obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client) obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
obj["waitForPeerConnection"] = createWaitForPeerConnectionMethod(client)
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
obj["setLogLevel"] = createSetLogLevelMethod(client)
return js.ValueOf(obj) return js.ValueOf(obj)
} }
// netBirdClientConstructor acts as a JavaScript constructor function // netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(this js.Value, args []js.Value) any { func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0] resolve := promiseArgs[0]
reject := promiseArgs[1] reject := promiseArgs[1]
@@ -263,7 +657,10 @@ func netBirdClientConstructor(this js.Value, args []js.Value) any {
return return
} }
if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil { // Force trace logging for debugging - must be set BEFORE netbird.New
// as New() will call logrus.SetLevel with options.LogLevel
options.LogLevel = "trace"
if err := util.InitLog("trace", util.LogConsole); err != nil {
log.Warnf("Failed to initialize logging: %v", err) log.Warnf("Failed to initialize logging: %v", err)
} }

8
go.mod
View File

@@ -78,8 +78,8 @@ require (
github.com/pion/logging v0.2.4 github.com/pion/logging v0.2.4
github.com/pion/randutil v0.1.0 github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0 github.com/pion/stun/v2 v2.0.0
github.com/pion/stun/v3 v3.0.0 github.com/pion/stun/v3 v3.1.0
github.com/pion/transport/v3 v3.0.7 github.com/pion/transport/v3 v3.1.1
github.com/pion/turn/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1
github.com/pkg/sftp v1.13.9 github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
@@ -241,7 +241,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.7 // indirect github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pion/turn/v4 v4.1.1 // indirect github.com/pion/turn/v4 v4.1.1 // indirect
@@ -263,7 +263,7 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect github.com/wlynxg/anet v0.0.5 // indirect
github.com/yuin/goldmark v1.7.8 // indirect github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect github.com/zeebo/blake3 v0.2.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect

16
go.sum
View File

@@ -444,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA= github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q= github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8= github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
@@ -455,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0= github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ= github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw= github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU= github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g= github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo= github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0= github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0= github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo= github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8= github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE= github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc= github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
@@ -574,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

View File

@@ -26,6 +26,8 @@ type UserDataCache interface {
Get(ctx context.Context, key string) (*idp.UserData, error) Get(ctx context.Context, key string) (*idp.UserData, error)
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
Delete(ctx context.Context, key string) error Delete(ctx context.Context, key string) error
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
} }
// UserDataCacheImpl is a struct that implements the UserDataCache interface. // UserDataCacheImpl is a struct that implements the UserDataCache interface.
@@ -51,6 +53,29 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
return u.cache.Delete(ctx, key) return u.cache.Delete(ctx, key)
} }
func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
var users []*idp.UserData
v, err := u.cache.Get(ctx, key, &users)
if err != nil {
return nil, err
}
switch v := v.(type) {
case []*idp.UserData:
return v, nil
case *[]*idp.UserData:
return *v, nil
case []byte:
return unmarshalUserData(v)
}
return nil, fmt.Errorf("unexpected type: %T", v)
}
func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
}
// NewUserDataCache creates a new UserDataCacheImpl object. // NewUserDataCache creates a new UserDataCacheImpl object.
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl { func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
simpleCache := cache.New[any](store) simpleCache := cache.New[any](store)

View File

@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token) m.patUsageTracker.IncrementUsage(token)
} }
if m.rateLimiter != nil { if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) { if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests") return r, status.Errorf(status.TooManyRequests, "too many requests")
} }
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
return nbcontext.SetUserAuthInRequest(r, userAuth), nil return nbcontext.SetUserAuthInRequest(r, userAuth), nil
} }
func isTerraformRequest(r *http.Request) bool {
ua := strings.ToLower(r.Header.Get("User-Agent"))
return strings.Contains(ua, "terraform")
}
// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts // getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
// the JWT token from the Authorization header. // the JWT token from the Authorization header.
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) { func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {

View File

@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
handler.ServeHTTP(rec, req) handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again") assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
}) })
t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Test various Terraform user agent formats
terraformUserAgents := []string{
"Terraform/1.5.0",
"terraform/1.0.0",
"Terraform-Provider/2.0.0",
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
}
for _, userAgent := range terraformUserAgents {
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", userAgent)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
})
}
})
t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
} }
func TestAuthMiddleware_Handler_Child(t *testing.T) { func TestAuthMiddleware_Handler_Child(t *testing.T) {

View File

@@ -911,10 +911,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{} accountUsers := []*types.User{}
switch { switch {
case allowed: case allowed:
start := time.Now()
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.WithContext(ctx).Tracef("Got %d users from account %s after %s", len(accountUsers), accountID, time.Since(start))
case user != nil && user.AccountID == accountID: case user != nil && user.AccountID == accountID:
accountUsers = append(accountUsers, user) accountUsers = append(accountUsers, user)
default: default:
@@ -933,23 +935,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
users := make(map[string]userLoggedInOnce, len(accountUsers)) users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0) usersFromIntegration := make([]*idp.UserData, 0)
filtered := make(map[string]*idp.UserData, len(accountUsers))
log.WithContext(ctx).Tracef("Querying users from IDP for account %s", accountID)
start := time.Now()
integrationKeys := make(map[string]struct{})
for _, user := range accountUsers { for _, user := range accountUsers {
if user.Issued == types.UserIssuedIntegration { if user.Issued == types.UserIssuedIntegration {
key := user.IntegrationReference.CacheKey(accountID, user.Id) integrationKeys[user.IntegrationReference.CacheKey(accountID)] = struct{}{}
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = true
continue
}
usersFromIntegration = append(usersFromIntegration, info)
continue continue
} }
if !user.IsServiceUser { if !user.IsServiceUser {
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero()) users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
} }
} }
for key := range integrationKeys {
usersData, err := am.externalCacheManager.GetUsers(am.ctx, key)
if err != nil {
log.WithContext(ctx).Debugf("GetUsers from ExternalCache for key: %s, error: %s", key, err)
continue
}
for _, ud := range usersData {
filtered[ud.ID] = ud
}
}
for _, ud := range filtered {
usersFromIntegration = append(usersFromIntegration, ud)
}
log.WithContext(ctx).Tracef("Got user info from external cache after %s", time.Since(start))
start = time.Now()
queriedUsers, err = am.lookupCache(ctx, users, accountID) queriedUsers, err = am.lookupCache(ctx, users, accountID)
log.WithContext(ctx).Tracef("Got user info from cache for %d users after %s", len(queriedUsers), time.Since(start))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1086,8 +1086,12 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
cacheManager := am.GetExternalCacheManager() cacheManager := am.GetExternalCacheManager()
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id) tud := &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}, time.Minute) cacheKeyUser := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKeyUser, tud, time.Minute)
assert.NoError(t, err)
cacheKeyAccount := externalUser.IntegrationReference.CacheKey(mockAccountID)
err = cacheManager.SetUsers(context.Background(), cacheKeyAccount, []*idp.UserData{tud}, time.Minute)
assert.NoError(t, err) assert.NoError(t, err)
infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/relay/server" "github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth" "github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -43,6 +45,10 @@ type Config struct {
LogLevel string LogLevel string
LogFile string LogFile string
HealthcheckListenAddress string HealthcheckListenAddress string
// STUN server configuration
EnableSTUN bool
STUNPorts []int
STUNLogLevel string
} }
func (c Config) Validate() error { func (c Config) Validate() error {
@@ -52,6 +58,25 @@ func (c Config) Validate() error {
if c.AuthSecret == "" { if c.AuthSecret == "" {
return fmt.Errorf("auth secret is required") return fmt.Errorf("auth secret is required")
} }
// Validate STUN configuration
if c.EnableSTUN {
if len(c.STUNPorts) == 0 {
return fmt.Errorf("--stun-ports is required when --enable-stun is set")
}
seen := make(map[int]bool)
for _, port := range c.STUNPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid STUN port %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d", port)
}
seen[port] = true
}
}
return nil return nil
} }
@@ -91,6 +116,9 @@ func init() {
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level") rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file") rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server") rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
rootCmd.PersistentFlags().BoolVar(&cobraConfig.EnableSTUN, "enable-stun", false, "enable embedded STUN server")
rootCmd.PersistentFlags().IntSliceVar(&cobraConfig.STUNPorts, "stun-ports", []int{3478}, "ports for the embedded STUN server (can be specified multiple times or comma-separated)")
rootCmd.PersistentFlags().StringVar(&cobraConfig.STUNLogLevel, "stun-log-level", "info", "log level for STUN server (panic, fatal, error, warn, info, debug, trace)")
setFlagsFromEnvVars(rootCmd) setFlagsFromEnvVars(rootCmd)
} }
@@ -119,21 +147,14 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err) return fmt.Errorf("failed to initialize log: %s", err)
} }
// Resource creation phase (fail fast before starting any goroutines)
metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "") metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
if err != nil { if err != nil {
log.Debugf("setup metrics: %v", err) log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err) return fmt.Errorf("setup metrics: %v", err)
} }
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start metrics server: %v", err)
}
}()
srvListenerCfg := server.ListenerConfig{ srvListenerCfg := server.ListenerConfig{
Address: cobraConfig.ListenAddress, Address: cobraConfig.ListenAddress,
} }
@@ -145,6 +166,12 @@ func execute(cmd *cobra.Command, args []string) error {
} }
srvListenerCfg.TLSConfig = tlsConfig srvListenerCfg.TLSConfig = tlsConfig
// Create STUN listeners early to fail fast
stunListeners, err := createSTUNListeners()
if err != nil {
return err
}
hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
@@ -155,60 +182,145 @@ func execute(cmd *cobra.Command, args []string) error {
TLSSupport: tlsSupport, TLSSupport: tlsSupport,
} }
srv, err := server.NewServer(cfg) srv, err := createRelayServer(cfg)
if err != nil { if err != nil {
log.Debugf("failed to create relay server: %v", err) cleanupSTUNListeners(stunListeners)
return fmt.Errorf("failed to create relay server: %v", err) return err
} }
hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := createHealthCheck(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return err
}
var stunServer *stun.Server
if len(stunListeners) > 0 {
stunServer = stun.NewServer(stunListeners, cobraConfig.STUNLogLevel)
}
// Start all servers (only after all resources are successfully created)
startServers(&wg, metricsServer, srv, srvListenerCfg, httpHealthcheck, stunServer)
waitForExitSignal()
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
err = shutdownServers(ctx, metricsServer, srv, httpHealthcheck, stunServer)
wg.Wait()
return err
}
func startServers(wg *sync.WaitGroup, metricsServer *metrics.Metrics, srv *server.Server, srvListenerCfg server.ListenerConfig, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) {
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
instanceURL := srv.InstanceURL() instanceURL := srv.InstanceURL()
log.Infof("server will be available on: %s", instanceURL.String()) log.Infof("server will be available on: %s", instanceURL.String())
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if err := srv.Listen(srvListenerCfg); err != nil { if err := srv.Listen(srvListenerCfg); err != nil {
log.Fatalf("failed to bind server: %s", err) log.Fatalf("failed to bind relay server: %s", err)
} }
}() }()
hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return fmt.Errorf("failed to create healthcheck server: %v", err)
}
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start healthcheck server: %v", err) log.Fatalf("failed to start healthcheck server: %v", err)
} }
}() }()
// it will block until exit signal if stunServer != nil {
waitForExitSignal() wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) func shutdownServers(ctx context.Context, metricsServer *metrics.Metrics, srv *server.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) error {
defer cancel() var errs error
var shutDownErrors error
if err := httpHealthcheck.Shutdown(ctx); err != nil { if err := httpHealthcheck.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err)) errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}
if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
} }
if err := srv.Shutdown(ctx); err != nil { if err := srv.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err)) errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
} }
log.Infof("shutting down metrics server") log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil { if err := metricsServer.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err)) errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
} }
wg.Wait() return errs
return shutDownErrors }
func createHealthCheck(hCfg healthcheck.Config) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return nil, fmt.Errorf("failed to create healthcheck server: %v", err)
}
return httpHealthcheck, nil
}
func createRelayServer(cfg server.Config) (*server.Server, error) {
srv, err := server.NewServer(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create relay server: %v", err)
}
return srv, nil
}
func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}
func createSTUNListeners() ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cobraConfig.EnableSTUN {
for _, port := range cobraConfig.STUNPorts {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
// Close already opened listeners on failure
cleanupSTUNListeners(stunListeners)
log.Debugf("failed to create STUN listener on port %d: %v", port, err)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %v", port, err)
}
stunListeners = append(stunListeners, listener)
}
}
return stunListeners, nil
} }
func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) { func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {

View File

@@ -8,6 +8,10 @@ pid="$(pgrep -x -f /usr/bin/netbird-ui || true)"
if [ -n "${pid}" ] if [ -n "${pid}" ]
then then
uid="$(cat /proc/"${pid}"/loginuid)" uid="$(cat /proc/"${pid}"/loginuid)"
# loginuid can be 4294967295 (-1) if not set, fall back to process uid
if [ "${uid}" = "4294967295" ] || [ "${uid}" = "-1" ]; then
uid="$(stat -c '%u' /proc/"${pid}")"
fi
username="$(id -nu "${uid}")" username="$(id -nu "${uid}")"
# Only re-run if it was already running # Only re-run if it was already running
pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1 pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1

170
stun/server.go Normal file
View File

@@ -0,0 +1,170 @@
// Package stun provides an embedded STUN server for NAT traversal discovery.
package stun
import (
"errors"
"fmt"
"net"
"sync"
"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/formatter"
"github.com/pion/stun/v3"
)
// ErrServerClosed is returned by Listen when the server is shut down gracefully.
var ErrServerClosed = errors.New("stun: server closed")
// ErrNoListeners is returned by Listen when no UDP connections were provided.
var ErrNoListeners = errors.New("stun: no listeners configured")
// Server implements a STUN server that responds to binding requests
// with the client's reflexive transport address.
type Server struct {
conns []*net.UDPConn
logger *log.Entry
logLevel log.Level
wg sync.WaitGroup
}
// NewServer creates a new STUN server with the given UDP listeners.
// The caller is responsible for creating and providing the listeners.
// logLevel can be: panic, fatal, error, warn, info, debug, trace
func NewServer(conns []*net.UDPConn, logLevel string) *Server {
level, err := log.ParseLevel(logLevel)
if err != nil {
level = log.InfoLevel
}
// Create a separate logger with its own level setting
// This allows --stun-log-level to work independently of --log-level
stunLogger := log.New()
stunLogger.SetOutput(log.StandardLogger().Out)
stunLogger.SetLevel(level)
// Use the formatter package to set up formatter, ReportCaller, and context hook
formatter.SetTextFormatter(stunLogger)
logger := stunLogger.WithField("component", "stun-server")
logger.Infof("STUN server log level set to: %s", level.String())
return &Server{
conns: conns,
logger: logger,
logLevel: level,
}
}
// Listen starts the STUN server and blocks until the server is shut down.
// Returns ErrServerClosed when shut down gracefully via Shutdown.
// Returns ErrNoListeners if no UDP connections were provided.
func (s *Server) Listen() error {
if len(s.conns) == 0 {
return ErrNoListeners
}
// Start a read loop for each listener
for _, conn := range s.conns {
s.logger.Infof("STUN server listening on %s", conn.LocalAddr())
s.wg.Add(1)
go s.readLoop(conn)
}
s.wg.Wait()
return ErrServerClosed
}
// readLoop continuously reads UDP packets and handles STUN requests.
func (s *Server) readLoop(conn *net.UDPConn) {
defer s.wg.Done()
buf := make([]byte, 1500) // Standard MTU size
for {
n, remoteAddr, err := conn.ReadFromUDP(buf)
if err != nil {
// Check if the connection was closed externally
if errors.Is(err, net.ErrClosed) {
s.logger.Info("UDP connection closed, stopping read loop")
return
}
s.logger.Warnf("failed to read UDP packet: %v", err)
continue
}
// Handle packet in the same goroutine to avoid complexity
// STUN responses are small and fast
s.handlePacket(conn, buf[:n], remoteAddr)
}
}
// handlePacket processes a STUN request and sends a response.
func (s *Server) handlePacket(conn *net.UDPConn, data []byte, addr *net.UDPAddr) {
localPort := conn.LocalAddr().(*net.UDPAddr).Port
s.logger.Debugf("[port:%d] received %d bytes from %s", localPort, len(data), addr)
// Check if it's a STUN message
if !stun.IsMessage(data) {
s.logger.Debugf("[port:%d] not a STUN message (first bytes: %x)", localPort, data[:min(len(data), 8)])
return
}
// Parse the STUN message
msg := &stun.Message{Raw: data}
if err := msg.Decode(); err != nil {
s.logger.Warnf("[port:%d] failed to decode STUN message from %s: %v", localPort, addr, err)
return
}
s.logger.Debugf("[port:%d] received STUN %s from %s (tx=%x)", localPort, msg.Type, addr, msg.TransactionID[:8])
// Only handle binding requests
if msg.Type != stun.BindingRequest {
s.logger.Debugf("[port:%d] ignoring non-binding request: %s", localPort, msg.Type)
return
}
// Build the response
response, err := stun.Build(
stun.NewTransactionIDSetter(msg.TransactionID),
stun.BindingSuccess,
&stun.XORMappedAddress{
IP: addr.IP,
Port: addr.Port,
},
stun.Fingerprint,
)
if err != nil {
s.logger.Errorf("[port:%d] failed to build STUN response: %v", localPort, err)
return
}
// Send the response on the same connection it was received on
n, err := conn.WriteToUDP(response.Raw, addr)
if err != nil {
s.logger.Errorf("[port:%d] failed to send STUN response to %s: %v", localPort, addr, err)
return
}
s.logger.Debugf("[port:%d] sent STUN BindingSuccess to %s (%d bytes) with XORMappedAddress %s:%d", localPort, addr, n, addr.IP, addr.Port)
}
// Shutdown gracefully stops the STUN server.
func (s *Server) Shutdown() error {
s.logger.Info("shutting down STUN server")
var merr *multierror.Error
for _, conn := range s.conns {
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
merr = multierror.Append(merr, fmt.Errorf("close STUN UDP connection: %w", err))
}
}
// Wait for all readLoops to finish
s.wg.Wait()
return nberrors.FormatErrorOrNil(merr)
}

479
stun/server_test.go Normal file
View File

@@ -0,0 +1,479 @@
package stun
import (
"errors"
"fmt"
"math/rand"
"net"
"os"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/pion/stun/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// createTestServer creates a STUN server listening on a random port for testing.
// Returns the server, the listener connection (caller must close), and the server address.
func createTestServer(t testing.TB) (*Server, *net.UDPConn, *net.UDPAddr) {
t.Helper()
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err)
server := NewServer([]*net.UDPConn{conn}, "debug")
return server, conn, conn.LocalAddr().(*net.UDPAddr)
}
// waitForServerReady polls the server with STUN binding requests until it responds.
// This avoids flaky tests on slow CI machines that relied on time.Sleep.
func waitForServerReady(t testing.TB, serverAddr *net.UDPAddr, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
retryInterval := 10 * time.Millisecond
clientConn, err := net.DialUDP("udp", nil, serverAddr)
require.NoError(t, err)
defer clientConn.Close()
buf := make([]byte, 1500)
for time.Now().Before(deadline) {
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
require.NoError(t, err)
_, err = clientConn.Write(msg.Raw)
require.NoError(t, err)
_ = clientConn.SetReadDeadline(time.Now().Add(retryInterval))
n, err := clientConn.Read(buf)
if err != nil {
// Timeout or other error, retry
continue
}
response := &stun.Message{Raw: buf[:n]}
if err := response.Decode(); err != nil {
continue
}
if response.Type == stun.BindingSuccess {
return // Server is ready
}
}
t.Fatalf("server did not become ready within %v", timeout)
}
func TestServer_BindingRequest(t *testing.T) {
// Start the STUN server on a random port
server, listener, serverAddr := createTestServer(t)
// Start server in background
serverErrCh := make(chan error, 1)
go func() {
serverErrCh <- server.Listen()
}()
// Wait for server to be ready
waitForServerReady(t, serverAddr, 2*time.Second)
// Create a UDP client
clientConn, err := net.DialUDP("udp", nil, serverAddr)
require.NoError(t, err)
defer clientConn.Close()
// Build a STUN binding request
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
require.NoError(t, err)
// Send the request
_, err = clientConn.Write(msg.Raw)
require.NoError(t, err)
// Read the response
buf := make([]byte, 1500)
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := clientConn.Read(buf)
require.NoError(t, err)
// Parse the response
response := &stun.Message{Raw: buf[:n]}
err = response.Decode()
require.NoError(t, err)
// Verify it's a binding success
assert.Equal(t, stun.BindingSuccess, response.Type)
// Extract the XOR-MAPPED-ADDRESS
var xorAddr stun.XORMappedAddress
err = xorAddr.GetFrom(response)
require.NoError(t, err)
// Verify the address matches our client's local address
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
assert.Equal(t, clientAddr.IP.String(), xorAddr.IP.String())
assert.Equal(t, clientAddr.Port, xorAddr.Port)
// Close listener first to unblock readLoop, then shutdown
_ = listener.Close()
err = server.Shutdown()
require.NoError(t, err)
}
func TestServer_IgnoresNonSTUNPackets(t *testing.T) {
server, listener, serverAddr := createTestServer(t)
go func() {
_ = server.Listen()
}()
waitForServerReady(t, serverAddr, 2*time.Second)
clientConn, err := net.DialUDP("udp", nil, serverAddr)
require.NoError(t, err)
defer clientConn.Close()
// Send non-STUN data
_, err = clientConn.Write([]byte("hello world"))
require.NoError(t, err)
// Try to read response (should timeout since server ignores non-STUN)
buf := make([]byte, 1500)
_ = clientConn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
_, err = clientConn.Read(buf)
assert.Error(t, err) // Should be a timeout error
// Close listener first to unblock readLoop, then shutdown
_ = listener.Close()
_ = server.Shutdown()
}
func TestServer_Shutdown(t *testing.T) {
server, listener, serverAddr := createTestServer(t)
serverDone := make(chan struct{})
go func() {
err := server.Listen()
assert.True(t, errors.Is(err, ErrServerClosed))
close(serverDone)
}()
waitForServerReady(t, serverAddr, 2*time.Second)
// Close listener first to unblock readLoop, then shutdown
_ = listener.Close()
err := server.Shutdown()
require.NoError(t, err)
// Wait for Listen to return
select {
case <-serverDone:
// Success
case <-time.After(3 * time.Second):
t.Fatal("server did not shutdown in time")
}
}
func TestServer_MultipleRequests(t *testing.T) {
server, listener, serverAddr := createTestServer(t)
go func() {
_ = server.Listen()
}()
waitForServerReady(t, serverAddr, 2*time.Second)
// Create multiple clients and send requests
for i := 0; i < 5; i++ {
func() {
clientConn, err := net.DialUDP("udp", nil, serverAddr)
require.NoError(t, err)
defer clientConn.Close()
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
require.NoError(t, err)
_, err = clientConn.Write(msg.Raw)
require.NoError(t, err)
buf := make([]byte, 1500)
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := clientConn.Read(buf)
require.NoError(t, err)
response := &stun.Message{Raw: buf[:n]}
err = response.Decode()
require.NoError(t, err)
assert.Equal(t, stun.BindingSuccess, response.Type)
}()
}
// Close listener first to unblock readLoop, then shutdown
_ = listener.Close()
_ = server.Shutdown()
}
func TestServer_ConcurrentClients(t *testing.T) {
numClients := 100
requestsPerClient := 5
maxStartDelay := 100 * time.Millisecond // Random delay before client starts
maxRequestDelay := 500 * time.Millisecond // Random delay between requests
// Remote server to test against via env var STUN_TEST_SERVER
// Example: STUN_TEST_SERVER=example.netbird.io:3478 go test -v ./stun/... -run ConcurrentClients
remoteServer := os.Getenv("STUN_TEST_SERVER")
var serverAddr *net.UDPAddr
var server *Server
var listener *net.UDPConn
if remoteServer != "" {
// Use remote server
var err error
serverAddr, err = net.ResolveUDPAddr("udp", remoteServer)
require.NoError(t, err)
t.Logf("Testing against remote server: %s", remoteServer)
} else {
// Start local server
server, listener, serverAddr = createTestServer(t)
go func() {
_ = server.Listen()
}()
waitForServerReady(t, serverAddr, 2*time.Second)
t.Logf("Testing against local server: %s", serverAddr)
}
var wg sync.WaitGroup
errorz := make(chan error, numClients*requestsPerClient)
successCount := make(chan int, numClients)
startTime := time.Now()
for i := 0; i < numClients; i++ {
wg.Add(1)
go func(clientID int) {
defer wg.Done()
// Random delay before starting
time.Sleep(time.Duration(rand.Int63n(int64(maxStartDelay))))
clientConn, err := net.DialUDP("udp", nil, serverAddr)
if err != nil {
errorz <- fmt.Errorf("client %d: failed to dial: %w", clientID, err)
return
}
defer clientConn.Close()
success := 0
for j := 0; j < requestsPerClient; j++ {
// Random delay between requests
if j > 0 {
time.Sleep(time.Duration(rand.Int63n(int64(maxRequestDelay))))
}
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
if err != nil {
errorz <- fmt.Errorf("client %d: failed to build request: %w", clientID, err)
continue
}
_, err = clientConn.Write(msg.Raw)
if err != nil {
errorz <- fmt.Errorf("client %d: failed to write: %w", clientID, err)
continue
}
buf := make([]byte, 1500)
_ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
n, err := clientConn.Read(buf)
if err != nil {
errorz <- fmt.Errorf("client %d: failed to read: %w", clientID, err)
continue
}
response := &stun.Message{Raw: buf[:n]}
if err := response.Decode(); err != nil {
errorz <- fmt.Errorf("client %d: failed to decode: %w", clientID, err)
continue
}
if response.Type != stun.BindingSuccess {
errorz <- fmt.Errorf("client %d: unexpected response type: %s", clientID, response.Type)
continue
}
success++
}
successCount <- success
}(i)
}
wg.Wait()
close(errorz)
close(successCount)
elapsed := time.Since(startTime)
totalSuccess := 0
for count := range successCount {
totalSuccess += count
}
var errs []error
for err := range errorz {
errs = append(errs, err)
}
totalRequests := numClients * requestsPerClient
t.Logf("Completed %d/%d requests in %v (%.2f req/s)",
totalSuccess, totalRequests, elapsed,
float64(totalSuccess)/elapsed.Seconds())
if len(errs) > 0 {
t.Logf("Errors (%d):", len(errs))
for i, err := range errs {
if i < 10 { // Only show first 10 errors
t.Logf(" - %v", err)
}
}
}
// Require at least 95% success rate
successRate := float64(totalSuccess) / float64(totalRequests)
require.GreaterOrEqual(t, successRate, 0.95, "success rate too low: %.2f%%", successRate*100)
// Cleanup local server if used
if server != nil {
// Close listener first to unblock readLoop, then shutdown
_ = listener.Close()
_ = server.Shutdown()
}
}
func TestServer_MultiplePorts(t *testing.T) {
// Create listeners on two random ports
conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err)
conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
require.NoError(t, err)
addr1 := conn1.LocalAddr().(*net.UDPAddr)
addr2 := conn2.LocalAddr().(*net.UDPAddr)
server := NewServer([]*net.UDPConn{conn1, conn2}, "debug")
go func() {
_ = server.Listen()
}()
// Wait for server to be ready (checking first port is sufficient)
waitForServerReady(t, addr1, 2*time.Second)
// Test requests on both ports
for _, serverAddr := range []*net.UDPAddr{addr1, addr2} {
func() {
clientConn, err := net.DialUDP("udp", nil, serverAddr)
require.NoError(t, err)
defer clientConn.Close()
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
require.NoError(t, err)
_, err = clientConn.Write(msg.Raw)
require.NoError(t, err)
buf := make([]byte, 1500)
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := clientConn.Read(buf)
require.NoError(t, err)
response := &stun.Message{Raw: buf[:n]}
err = response.Decode()
require.NoError(t, err)
assert.Equal(t, stun.BindingSuccess, response.Type)
var xorAddr stun.XORMappedAddress
err = xorAddr.GetFrom(response)
require.NoError(t, err)
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
assert.Equal(t, clientAddr.Port, xorAddr.Port)
}()
}
// Close listeners first to unblock readLoops, then shutdown
_ = conn1.Close()
_ = conn2.Close()
_ = server.Shutdown()
}
// BenchmarkSTUNServer benchmarks the STUN server with concurrent clients
func BenchmarkSTUNServer(b *testing.B) {
server, listener, serverAddr := createTestServer(b)
go func() {
_ = server.Listen()
}()
waitForServerReady(b, serverAddr, 2*time.Second)
// Capture first error atomically - b.Fatal cannot be called from worker goroutines
var firstErr atomic.Pointer[error]
setErr := func(err error) {
firstErr.CompareAndSwap(nil, &err)
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
// Stop work if an error has occurred
if firstErr.Load() != nil {
return
}
clientConn, err := net.DialUDP("udp", nil, serverAddr)
if err != nil {
setErr(err)
return
}
defer clientConn.Close()
buf := make([]byte, 1500)
for pb.Next() {
if firstErr.Load() != nil {
return
}
msg, _ := stun.Build(stun.TransactionID, stun.BindingRequest)
_, _ = clientConn.Write(msg.Raw)
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err := clientConn.Read(buf)
if err != nil {
setErr(err)
return
}
response := &stun.Message{Raw: buf[:n]}
if err := response.Decode(); err != nil {
setErr(err)
return
}
}
})
b.StopTimer()
// Fail after RunParallel completes
if errPtr := firstErr.Load(); errPtr != nil {
b.Fatal(*errPtr)
}
// Close listener first to unblock readLoop, then shutdown
_ = listener.Close()
_ = server.Shutdown()
}