mirror of https://github.com/netbirdio/netbird.git
synced 2026-05-07 09:19:59 +00:00

Compare commits: ci-win-tes ... wasm-debug (7 commits)

| SHA1 |
|---|
| 063cbdc6d8 |
| 291e640b28 |
| efb954b7d6 |
| cac9326d3d |
| 520d9c66cf |
| ff10498a8b |
| 00b747ad5d |
@@ -314,9 +314,8 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
profName = activeProf.Name
}

statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName),
)
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName)
statusOutputString = overview.FullDetailSummary()
}
return statusOutputString
}

@@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error {
var statusOutputString string
switch {
case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
statusOutputString = outputInformationHolder.FullDetailSummary()
case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
statusOutputString, err = outputInformationHolder.JSON()
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
statusOutputString, err = outputInformationHolder.YAML()
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false)
}

if err != nil {
@@ -10,6 +10,7 @@ import (
"net/netip"
"os"
"sync"
"time"

"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)

var (
@@ -29,6 +31,11 @@ var (
ErrConfigNotInitialized = errors.New("config not initialized")
)

const (
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond
)

// Client manages a netbird embedded client instance.
type Client struct {
deviceName string
@@ -38,6 +45,7 @@ type Client struct {
setupKey string
jwtToken string
connect *internal.ConnectClient
recorder *peer.Status
}

// Options configures a new Client.
@@ -161,11 +169,17 @@ func New(opts Options) (*Client, error) {
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
if c.connect != nil {
return ErrClientAlreadyStarted
}

ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background()))
defer func() {
if c.connect == nil {
cancel()
}
}()

// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
@@ -173,7 +187,9 @@ func (c *Client) Start(startCtx context.Context) error {
}

recorder := peer.NewRecorder(c.config.ManagementURL.String())
c.recorder = recorder
client := internal.NewConnectClient(ctx, c.config, recorder, false)
client.SetSyncResponsePersistence(true)

// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
@@ -197,6 +213,7 @@ func (c *Client) Start(startCtx context.Context) error {
}

c.connect = client
c.cancel = cancel

return nil
}
@@ -211,17 +228,23 @@ func (c *Client) Stop(ctx context.Context) error {
return ErrClientNotStarted
}

if c.cancel != nil {
c.cancel()
c.cancel = nil
}

done := make(chan error, 1)
connect := c.connect
go func() {
done <- c.connect.Stop()
done <- connect.Stop()
}()

select {
case <-ctx.Done():
c.cancel = nil
c.connect = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
c.connect = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
@@ -241,18 +264,40 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {

// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
// With lazy connections, the connection is established on first traffic.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
logrus.Infof("embed.Dial called: network=%s, address=%s", network, address)

// Check context status upfront
if ctx.Err() != nil {
logrus.Warnf("embed.Dial: context already cancelled/expired: %v", ctx.Err())
return nil, ctx.Err()
}

engine, err := c.getEngine()
if err != nil {
logrus.Errorf("embed.Dial: getEngine failed: %v", err)
return nil, err
}

nsnet, err := engine.GetNet()
if err != nil {
logrus.Errorf("embed.Dial: GetNet failed: %v", err)
return nil, fmt.Errorf("get net: %w", err)
}

return nsnet.DialContext(ctx, network, address)
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when DialContext is called. The netstack
// dial triggers WireGuard traffic which activates the lazy connection.

logrus.Debugf("embed.Dial: calling nsnet.DialContext for %s", address)
conn, err := nsnet.DialContext(ctx, network, address)
if err != nil {
logrus.Errorf("embed.Dial: nsnet.DialContext failed: %v", err)
return nil, err
}
logrus.Infof("embed.Dial: successfully connected to %s", address)
return conn, nil
}

// DialContext dials a network address in the netbird network with context
@@ -315,6 +360,90 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}

// Status returns the current status of the client.
func (c *Client) Status() (peer.FullStatus, error) {
c.mu.Lock()
recorder := c.recorder
connect := c.connect
c.mu.Unlock()

if recorder == nil {
return peer.FullStatus{}, errors.New("client not started")
}

if connect != nil {
engine := connect.Engine()
if engine != nil {
_ = engine.RunHealthProbes(false)
}
}

return recorder.GetFullStatus(), nil
}

// GetLatestSyncResponse returns the latest sync response from the management server.
func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}

syncResp, err := engine.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get sync response: %w", err)
}

return syncResp, nil
}

// WaitForPeerConnection waits for a peer with the given IP to be connected.
func (c *Client) WaitForPeerConnection(ctx context.Context, peerIP string) error {
logrus.Infof("Waiting for peer %s to be connected", peerIP)

ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for peer %s to connect: %w", peerIP, ctx.Err())
case <-ticker.C:
status, err := c.Status()
if err != nil {
logrus.Debugf("Error getting status while waiting for peer: %v", err)
continue
}

for _, p := range status.Peers {
if p.IP == peerIP && p.ConnStatus == peer.StatusConnected {
logrus.Infof("Peer %s is now connected (relayed: %v)", peerIP, p.Relayed)
return nil
}
}
logrus.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
}

// SetLogLevel sets the logging level for the client and its components.
func (c *Client) SetLogLevel(levelStr string) error {
level, err := logrus.ParseLevel(levelStr)
if err != nil {
return fmt.Errorf("parse log level: %w", err)
}

logrus.SetLevel(level)

c.mu.Lock()
connect := c.connect
c.mu.Unlock()

// Note: ConnectClient doesn't have SetLogLevel method
_ = connect

return nil
}

// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
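For context, a minimal sketch of how the embedded client methods touched above (Start, Dial) might be driven from Go code. The helper name, package name, and timeout values are illustrative assumptions and not part of this changeset; only the import path and method signatures come from the diff.

```go
package embedexample

import (
	"context"
	"fmt"
	"net"
	"time"

	netbird "github.com/netbirdio/netbird/client/embed"
)

// dialThroughMesh is a hypothetical helper, not part of the diff. It starts the
// embedded client and then dials a peer address through the userspace netstack;
// with lazy connections the peer link is only established once this first
// traffic flows, as the Dial comment above notes.
func dialThroughMesh(c *netbird.Client, addr string) (net.Conn, error) {
	startCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	if err := c.Start(startCtx); err != nil {
		return nil, fmt.Errorf("start embedded client: %w", err)
	}

	dialCtx, cancelDial := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancelDial()
	return c.Dial(dialCtx, "tcp", addr)
}
```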
@@ -23,10 +23,10 @@ func NewNSDialer(net *netstack.Net) *NSDialer {
}

func (d *NSDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
log.Debugf("dialing %s %s", network, addr)
log.Infof("NSDialer.Dial: network=%s, addr=%s", network, addr)
conn, err := d.net.Dial(network, addr)
if err != nil {
log.Debugf("failed to deal connection: %s", err)
log.Warnf("NSDialer.Dial failed: %s", err)
}
return conn, err
}

@@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error)
return syncResponse, nil
}

// SetLogLevel sets the log level for the firewall manager if the engine is running.
func (c *ConnectClient) SetLogLevel(level log.Level) {
engine := c.Engine()
if engine == nil {
return
}

fwManager := engine.GetFirewallManager()
if fwManager != nil {
fwManager.SetLogLevel(level)
}
}

// Status returns the current client status
func (c *ConnectClient) Status() StatusType {
if c == nil {
@@ -16,8 +16,8 @@ import (

const (
PriorityMgmtCache = 150
PriorityLocal = 100
PriorityDNSRoute = 75
PriorityDNSRoute = 100
PriorityLocal = 75
PriorityUpstream = 50
PriorityDefault = 1
PriorityFallback = -100

@@ -631,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {

handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,
@@ -743,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority)
handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
domainGroup.domain,
@@ -926,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() {

handler, err := newUpstreamResolver(
s.ctx,
s.wgInterface.Name(),
s.wgInterface.Address().IP,
s.wgInterface.Address().Network,
s.wgInterface,
s.statusRecorder,
s.hostsDNSHolder,
nbdns.RootZone,
@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
return configurer.WGStats{}, nil
}

func (w *mocWGIface) GetNet() *netstack.Net {
return nil
}

var zoneRecords = []nbdns.SimpleRecord{
{
Name: "peera.netbird.cloud",
@@ -180,7 +185,7 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -221,19 +226,19 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -251,11 +256,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -273,11 +278,11 @@ func TestUpdateDNSServer(t *testing.T) {
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap),
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
@@ -299,7 +304,7 @@ func TestUpdateDNSServer(t *testing.T) {
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -315,7 +320,7 @@ func TestUpdateDNSServer(t *testing.T) {
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
@@ -2047,7 +2052,7 @@ func TestLocalResolverPriorityInServer(t *testing.T) {

func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityDNSRoute, PriorityLocal, "DNS Route should be higher than Local priority")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
@@ -18,6 +18,7 @@ import (
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack"

"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
@@ -418,6 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil
}

// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp")
if err != nil {
return nil, err
}

// If response is truncated, retry with TCP
if reply != nil && reply.MsgHdr.Truncated {
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
return netstackExchange(ctx, nsNet, r, upstream, "tcp")
}

return reply, nil
}

func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) {
conn, err := nsNet.DialContext(ctx, network, upstream)
if err != nil {
return nil, fmt.Errorf("with %s: %w", network, err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close DNS connection: %v", err)
}
}()

if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetDeadline(deadline); err != nil {
return nil, fmt.Errorf("set deadline: %w", err)
}
}

dnsConn := &dns.Conn{Conn: conn}

if err := dnsConn.WriteMsg(r); err != nil {
return nil, fmt.Errorf("write %s message: %w", network, err)
}

reply, err := dnsConn.ReadMsg()
if err != nil {
return nil, fmt.Errorf("read %s message: %w", network, err)
}

return reply, nil
}

// FormatPeerStatus formats peer connection status information for debugging DNS timeouts
func FormatPeerStatus(peerState *peer.State) string {
isConnected := peerState.ConnStatus == peer.StatusConnected
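For reference, a minimal sketch of how the new ExchangeWithNetstack could be driven from within the same dns package (reusing the imports shown in the hunk above). The helper name, the query name, the upstream address, and the timeout are assumptions for illustration only.

```go
// queryViaNetstack is a hypothetical helper, not part of the diff. It sends an
// A query through the netstack-backed tunnel; ExchangeWithNetstack tries UDP
// first and retries over TCP when the reply comes back truncated.
func queryViaNetstack(nsNet *netstack.Net, name, upstream string) (*dns.Msg, error) {
	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(name), dns.TypeA)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return ExchangeWithNetstack(ctx, nsNet, m, upstream)
}
```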
@@ -23,9 +23,7 @@ type upstreamResolver struct {
// first time, and we need to wait for a while to start to use again the proper DNS resolver.
func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
_ WGIface,
statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder,
domain string,

@@ -5,22 +5,23 @@ package dns
import (
"context"
"net/netip"
"runtime"
"time"

"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"

"github.com/netbirdio/netbird/client/internal/peer"
)

type upstreamResolver struct {
*upstreamResolverBase
nsNet *netstack.Net
}

func newUpstreamResolver(
ctx context.Context,
_ string,
_ netip.Addr,
_ netip.Prefix,
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -28,12 +29,23 @@ func newUpstreamResolver(
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
nonIOS := &upstreamResolver{
upstreamResolverBase: upstreamResolverBase,
nsNet: wgIface.GetNet(),
}
upstreamResolverBase.upstreamClient = nonIOS
return nonIOS, nil
}

func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
// TODO: Check if upstream DNS server is routed through a peer before using netstack.
// Similar to iOS logic, we should determine if the DNS server is reachable directly
// or needs to go through the tunnel, and only use netstack when necessary.
// For now, only use netstack on JS platform where direct access is not possible.
if u.nsNet != nil && runtime.GOOS == "js" {
start := time.Now()
reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream)
return reply, time.Since(start), err
}

client := &dns.Client{
Timeout: ClientTimeout,
}

@@ -26,9 +26,7 @@ type upstreamResolverIOS struct {

func newUpstreamResolver(
ctx context.Context,
interfaceName string,
ip netip.Addr,
net netip.Prefix,
wgIface WGIface,
statusRecorder *peer.Status,
_ *hostsDNSHolder,
domain string,
@@ -37,9 +35,9 @@ func newUpstreamResolver(

ios := &upstreamResolverIOS{
upstreamResolverBase: upstreamResolverBase,
lIP: ip,
lNet: net,
interfaceName: interfaceName,
lIP: wgIface.Address().IP,
lNet: wgIface.Address().Network,
interfaceName: wgIface.Name(),
}
ios.upstreamClient = ios
@@ -2,13 +2,17 @@ package dns

import (
"context"
"net"
"net/netip"
"strings"
"testing"
"time"

"github.com/miekg/dns"
"golang.zx2c4.com/wireguard/tun/netstack"

"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/test"
)

@@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".")
// Convert test servers to netip.AddrPort
var servers []netip.AddrPort
for _, server := range testCase.InputServers {
@@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
}
}

type mockNetstackProvider struct{}

func (m *mockNetstackProvider) Name() string { return "mock" }
func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} }
func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil }
func (m *mockNetstackProvider) IsUserspaceBind() bool { return false }
func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil }
func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil }
func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil }
func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
return "", nil
}

type mockUpstreamResolver struct {
r *dns.Msg
rtt time.Duration
@@ -5,6 +5,8 @@ package dns
import (
"net"

"golang.zx2c4.com/wireguard/tun/netstack"

"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -17,4 +19,5 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}

@@ -1,6 +1,8 @@
package dns

import (
"golang.zx2c4.com/wireguard/tun/netstack"

"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -12,5 +14,6 @@ type WGIface interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
GetInterfaceGUIDString() (string, error)
}
@@ -1748,22 +1748,26 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
}

e.syncMsgMux.Unlock()
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)

// Skip STUN/TURN probing for JS/WASM as it's not available
relayHealthy := true
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
if runtime.GOOS != "js" {
var results []relay.ProbeResult
if waitForResult {
results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns)
} else {
results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns)
}
e.statusRecorder.UpdateRelayStates(results)

for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
}
}
log.Debugf("relay health check: healthy=%t", relayHealthy)
}
log.Debugf("relay health check: healthy=%t", relayHealthy)

allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy)
@@ -14,6 +14,7 @@ import (
"golang.org/x/exp/maps"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"

firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -158,6 +159,7 @@ type FullStatus struct {
NSGroupStates []NSGroupState
NumOfForwardingRules int
LazyConnectionEnabled bool
Events []*proto.SystemEvent
}

type StatusChangeSubscription struct {
@@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus {
}

fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
fullStatus.Events = d.GetEventHistory()
return fullStatus
}

@@ -1181,3 +1184,97 @@ type EventSubscription struct {
func (s *EventSubscription) Events() <-chan *proto.SystemEvent {
return s.events
}

// ToProto converts FullStatus to proto.FullStatus.
func (fs FullStatus) ToProto() *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}

pbFullStatus.ManagementState.URL = fs.ManagementState.URL
pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected
if err := fs.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}

pbFullStatus.SignalState.URL = fs.SignalState.URL
pbFullStatus.SignalState.Connected = fs.SignalState.Connected
if err := fs.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}

pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled
pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled

pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes)

for _, peerState := range fs.Peers {
networks := maps.Keys(peerState.GetRoutes())

pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: networks,
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}

for _, relayState := range fs.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}

for _, dnsState := range fs.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}

var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}

pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}

pbFullStatus.Events = fs.Events

return &pbFullStatus
}
@@ -17,13 +17,13 @@ import (

nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/common"
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
iface "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -38,11 +38,6 @@ type internalDNATer interface {
AddInternalDNATMapping(netip.Addr, netip.Addr) error
}

type wgInterface interface {
Name() string
Address() wgaddr.Address
}

type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
@@ -52,7 +47,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
wgInterface wgInterface
wgInterface iface.WGIface
peerStore *peerstore.Store
firewall firewall.Manager
fakeIPManager *fakeip.Manager
@@ -250,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}

client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return
}

if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true
}
@@ -264,20 +253,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()

startTime := time.Now()
reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger)
if reply == nil {
return
}

@@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR
return
}

// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client.
// Returns the DNS reply on success, or nil on error (error responses are written internally).
func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg {
startTime := time.Now()

nsNet := d.wgInterface.GetNet()
var reply *dns.Msg
var err error

if nsNet != nil {
reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream)
} else {
client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout)
if clientErr != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr))
return nil
}
reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream)
}

if err == nil {
return reply
}

if errors.Is(err, context.DeadlineExceeded) {
elapsed := time.Since(startTime)
peerInfo := d.debugPeerTimeout(upstreamIP, peerKey)
logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v",
elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err)
} else {
logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
}
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
logger.Errorf("failed writing DNS response: %v", err)
}
return nil
}

func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string {
if d.statusRecorder == nil {
return ""
@@ -4,6 +4,8 @@ import (
"net"
"net/netip"

"golang.zx2c4.com/wireguard/tun/netstack"

"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -18,4 +20,5 @@ type wgIfaceBase interface {
IsUserspaceBind() bool
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetNet() *netstack.Net
}
@@ -173,20 +173,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (

log.SetLevel(level)

if s.connectClient == nil {
return nil, fmt.Errorf("connect client not initialized")
if s.connectClient != nil {
s.connectClient.SetLogLevel(level)
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("engine not initialized")
}

fwManager := engine.GetFirewallManager()
if fwManager == nil {
return nil, fmt.Errorf("firewall manager not initialized")
}

fwManager.SetLogLevel(level)

log.Infof("Log level set to %s", level.String())

@@ -1,8 +1,6 @@
package server

import (
"context"

log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/client/proto"
@@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo
}
}
}

func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) {
events := s.statusRecorder.GetEventHistory()
return &proto.GetEventsResponse{Events: events}, nil
}
@@ -13,15 +13,12 @@ import (
"time"

"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"

log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -1067,11 +1064,9 @@ func (s *Server) Status(
if msg.GetFullPeerStatus {
s.runProbes(msg.ShouldRunProbes)
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus := fullStatus.ToProto()
pbFullStatus.Events = s.statusRecorder.GetEventHistory()

pbFullStatus.SshServerState = s.getSSHServerState()

statusResponse.FullStatus = pbFullStatus
}
@@ -1600,94 +1595,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio
return defaultDuration
}

func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}

pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}

pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}

pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled

for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
SshHostKey: peerState.SSHHostKey,
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}

for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}

for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}

var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}

pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}

return &pbFullStatus
}

// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {
@@ -325,61 +325,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) {
}
}

func ParseToJSON(overview OutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
// JSON returns the status overview as a JSON string.
func (o *OutputOverview) JSON() (string, error) {
jsonBytes, err := json.Marshal(o)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}

func ParseToYAML(overview OutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
// YAML returns the status overview as a YAML string.
func (o *OutputOverview) YAML() (string, error) {
yamlBytes, err := yaml.Marshal(o)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}

func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
// GeneralSummary returns a general summary of the status overview.
func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string {
var managementConnString string
if overview.ManagementState.Connected {
if o.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
if o.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error)
}
}

var signalConnString string
if overview.SignalState.Connected {
if o.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
if o.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error)
}
}

interfaceTypeString := "Userspace"
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceIP := o.IP
if o.KernelInterface {
interfaceTypeString = "Kernel"
} else if overview.IP == "" {
} else if o.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}

var relaysString string
if showRelays {
for _, relay := range overview.Relays.Details {
for _, relay := range o.Relays.Details {
available := "Available"
reason := ""

@@ -395,18 +398,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total)
}

networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
if len(o.Networks) > 0 {
sort.Strings(o.Networks)
networks = strings.Join(o.Networks, ", ")
}

var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
for _, nsServerGroup := range o.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
@@ -430,25 +433,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups))
}

rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
if o.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
if o.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}

lazyConnectionEnabledStatus := "false"
if overview.LazyConnectionEnabled {
if o.LazyConnectionEnabled {
lazyConnectionEnabledStatus = "true"
}

sshServerStatus := "Disabled"
if overview.SSHServerState.Enabled {
sessionCount := len(overview.SSHServerState.Sessions)
if o.SSHServerState.Enabled {
sessionCount := len(o.SSHServerState.Sessions)
if sessionCount > 0 {
sessionWord := "session"
if sessionCount > 1 {
@@ -460,7 +463,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
}

if showSSHSessions && sessionCount > 0 {
for _, session := range overview.SSHServerState.Sessions {
for _, session := range o.SSHServerState.Sessions {
var sessionDisplay string
if session.JWTUsername != "" {
sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s",
@@ -484,7 +487,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
}
}

peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total)

goos := runtime.GOOS
goarch := runtime.GOARCH
@@ -512,30 +515,31 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool,
"Forwarding rules: %d\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
o.DaemonVersion,
version.NetbirdVersion(),
overview.ProfileName,
o.ProfileName,
managementConnString,
signalConnString,
relaysString,
dnsServersString,
domain.Domain(overview.FQDN).SafeString(),
domain.Domain(o.FQDN).SafeString(),
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
lazyConnectionEnabledStatus,
sshServerStatus,
networks,
overview.NumberOfForwardingRules,
o.NumberOfForwardingRules,
peersCountString,
)
return summary
}

func ParseToFullDetailSummary(overview OutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
parsedEventsString := parseEvents(overview.Events)
summary := ParseGeneralSummary(overview, true, true, true, true)
// FullDetailSummary returns a full detailed summary with peer details and events.
func (o *OutputOverview) FullDetailSummary() string {
parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive)
parsedEventsString := parseEvents(o.Events)
summary := o.GeneralSummary(true, true, true, true)

return fmt.Sprintf(
"Peers detail:"+
@@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) {
}

func TestParsingToJSON(t *testing.T) {
jsonString, _ := ParseToJSON(overview)
jsonString, _ := overview.JSON()

//@formatter:off
expectedJSONString := `
@@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) {
}

func TestParsingToYAML(t *testing.T) {
yaml, _ := ParseToYAML(overview)
yaml, _ := overview.YAML()

expectedYAML :=
`peers:
@@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) {
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)

detail := ParseToFullDetailSummary(overview)
detail := overview.FullDetailSummary()

expectedDetail := fmt.Sprintf(
`Peers detail:
@@ -575,7 +575,7 @@ Peers count: 2/2 Connected
}

func TestParsingToShortVersion(t *testing.T) {
shortVersion := ParseGeneralSummary(overview, false, false, false, false)
shortVersion := overview.GeneralSummary(false, false, false, false)

expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
@@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName)
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
postUpStatusOutput = overview.FullDetailSummary()
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput)
@@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName)
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
preDownStatusOutput = overview.FullDetailSummary()
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
time.Now().Format(time.RFC3339), params.duration)
@@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName)
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
statusOutput = overview.FullDetailSummary()
}

request := &proto.DebugBundleRequest{
@@ -9,20 +9,31 @@ import (
"time"

log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/encoding/protojson"

netbird "github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/client/proto"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)

const (
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
defaultPeerConnectionTimeout = 60 * time.Second
peerConnectionPollInterval = 500 * time.Millisecond

icmpEchoRequest = 8
icmpCodeEcho = 0
pingBufferSize = 1500
)

func main() {
@@ -113,18 +124,45 @@ func createStopMethod(client *netbird.Client) js.Func {
})
}

// validateSSHArgs validates SSH connection arguments
func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) {
if len(args) < 2 {
return "", 0, "", js.ValueOf("error: requires host and port")
}

if args[0].Type() != js.TypeString {
return "", 0, "", js.ValueOf("host parameter must be a string")
}
if args[1].Type() != js.TypeNumber {
return "", 0, "", js.ValueOf("port parameter must be a number")
}

host = args[0].String()
port = args[1].Int()
username = "root"

if len(args) > 2 {
if args[2].Type() == js.TypeString && args[2].String() != "" {
username = args[2].String()
} else if args[2].Type() != js.TypeString {
return "", 0, "", js.ValueOf("username parameter must be a string")
}
}

return host, port, username, js.Undefined()
}

// createSSHMethod creates the SSH connection method
func createSSHMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: requires host and port")
}

host := args[0].String()
port := args[1].Int()
username := "root"
if len(args) > 2 && args[2].String() != "" {
username = args[2].String()
host, port, username, validationErr := validateSSHArgs(args)
if !validationErr.IsUndefined() {
if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" {
return validationErr
}
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(validationErr)
})
}

var jwtToken string
@@ -133,6 +171,9 @@ func createSSHMethod(client *netbird.Client) js.Func {
}

return createPromise(func(resolve, reject js.Value) {
// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when Dial is called in ssh.Connect().

sshClient := ssh.NewClient(client)

if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
@@ -154,6 +195,110 @@ func createSSHMethod(client *netbird.Client) js.Func {
})
}

func performPing(client *netbird.Client, hostname string) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()

start := time.Now()
conn, err := client.Dial(ctx, "ping", hostname)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err))
return
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("failed to close ping connection: %v", err)
}
}()

icmpData := make([]byte, 8)
icmpData[0] = icmpEchoRequest
icmpData[1] = icmpCodeEcho

if _, err := conn.Write(icmpData); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err))
return
}

buf := make([]byte, pingBufferSize)
if _, err := conn.Read(buf); err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err))
return
}

latency := time.Since(start)
js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds()))
}

func performPingTCP(client *netbird.Client, hostname string, port int) {
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
defer cancel()

address := fmt.Sprintf("%s:%d", hostname, port)
start := time.Now()
conn, err := client.Dial(ctx, "tcp", address)
if err != nil {
js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err))
return
}
latency := time.Since(start)

if err := conn.Close(); err != nil {
log.Debugf("failed to close TCP connection: %v", err)
}

js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds()))
}

// createPingMethod creates the ping method
func createPingMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: hostname required")
}

if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}

hostname := args[0].String()
return createPromise(func(resolve, reject js.Value) {
performPing(client, hostname)
resolve.Invoke(js.Undefined())
})
})
}

// createPingTCPMethod creates the pingtcp method
func createPingTCPMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: hostname and port required")
}

if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}

if args[1].Type() != js.TypeNumber {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a number"))
})
}

hostname := args[0].String()
port := args[1].Int()
return createPromise(func(resolve, reject js.Value) {
performPingTCP(client, hostname, port)
resolve.Invoke(js.Undefined())
})
})
}

// createProxyRequestMethod creates the proxyRequest method
func createProxyRequestMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -162,6 +307,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func {
}

request := args[0]
if request.Type() != js.TypeObject {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("request parameter must be an object"))
})
}

return createPromise(func(resolve, reject js.Value) {
response, err := http.ProxyRequest(client, request)
@@ -181,23 +331,255 @@ func createRDPProxyMethod(client *netbird.Client) js.Func {
return js.ValueOf("error: hostname and port required")
}

if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("hostname parameter must be a string"))
})
}
if args[1].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("port parameter must be a string"))
})
}

proxy := rdp.NewRDCleanPathProxy(client)
return proxy.CreateProxy(args[0].String(), args[1].String())
})
}

// getStatusOverview is a helper to get the status overview
func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) {
fullStatus, err := client.Status()
if err != nil {
return nbstatus.OutputOverview{}, err
}

pbFullStatus := fullStatus.ToProto()
statusResp := &proto.StatusResponse{
DaemonVersion: version.NetbirdVersion(),
FullStatus: pbFullStatus,
}

return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil
}

// createStatusMethod creates the status method that returns JSON
func createStatusMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}

jsonStr, err := overview.JSON()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}
jsonObj := js.Global().Get("JSON").Call("parse", jsonStr)
resolve.Invoke(jsonObj)
})
})
}

// createStatusSummaryMethod creates the statusSummary method
func createStatusSummaryMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
overview, err := getStatusOverview(client)
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}

summary := overview.GeneralSummary(false, false, false, false)
js.Global().Get("console").Call("log", summary)
resolve.Invoke(js.Undefined())
})
})
}

// createStatusDetailMethod creates the statusDetail method
func createStatusDetailMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
log.Info("statusDetail called")
return createPromise(func(resolve, reject js.Value) {
log.Info("statusDetail: getting overview")
overview, err := getStatusOverview(client)
if err != nil {
log.Errorf("statusDetail: getStatusOverview failed: %v", err)
reject.Invoke(js.ValueOf(err.Error()))
return
}

log.Info("statusDetail: generating full detail summary")
detail := overview.FullDetailSummary()
log.Infof("statusDetail: detail length=%d", len(detail))
js.Global().Get("console").Call("log", detail)
resolve.Invoke(js.Undefined())
})
})
}

// createWaitForPeerConnectionMethod creates a method that waits for a peer to be connected
func createWaitForPeerConnectionMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
if len(args) < 1 {
reject.Invoke(js.ValueOf("peer IP address required"))
return
}

peerIP := args[0].String()
timeoutMs := int(defaultPeerConnectionTimeout.Milliseconds())
if len(args) > 1 && !args[1].IsUndefined() && !args[1].IsNull() {
timeoutMs = args[1].Int()
}

timeout := time.Duration(timeoutMs) * time.Millisecond
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

log.Infof("Waiting for peer %s to be connected (timeout: %v)", peerIP, timeout)

ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
reject.Invoke(js.ValueOf(fmt.Sprintf("timeout waiting for peer %s to connect", peerIP)))
return
case <-ticker.C:
overview, err := getStatusOverview(client)
if err != nil {
log.Debugf("Error getting status: %v", err)
continue
}

for _, peer := range overview.Peers.Details {
if peer.IP == peerIP && peer.Status == "Connected" {
log.Infof("Peer %s is now connected (type: %s)", peerIP, peer.ConnType)
resolve.Invoke(js.ValueOf(map[string]interface{}{
"ip": peer.IP,
"status": peer.Status,
"connType": peer.ConnType,
}))
return
}
}
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
})
})
}

// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON
func createGetSyncResponseMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
return createPromise(func(resolve, reject js.Value) {
syncResp, err := client.GetLatestSyncResponse()
if err != nil {
reject.Invoke(js.ValueOf(err.Error()))
return
}

options := protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
AllowPartial: true,
}
jsonBytes, err := options.Marshal(syncResp)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err)))
return
}

jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes))
resolve.Invoke(jsonObj)
})
})
}

// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level
func createSetLogLevelMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
return js.ValueOf("error: log level required")
}

if args[0].Type() != js.TypeString {
return createPromise(func(resolve, reject js.Value) {
reject.Invoke(js.ValueOf("log level parameter must be a string"))
})
}

logLevel := args[0].String()
return createPromise(func(resolve, reject js.Value) {
if err := client.SetLogLevel(logLevel); err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err)))
return
}
log.Infof("Log level set to: %s", logLevel)
resolve.Invoke(js.ValueOf(true))
})
})
}

// createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]

go handler(resolve, reject)
// Wrap reject to always log the error
loggingReject := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) > 0 {
log.Errorf("Promise rejected: %v", args[0])
js.Global().Get("console").Call("error", "WASM Promise rejected:", args[0])
}
reject.Invoke(args[0])
return nil
})

go handler(resolve, loggingReject.Value)

return nil
}))
}

// waitForPeerConnected waits for a peer with the given IP to be connected
func waitForPeerConnected(ctx context.Context, client *netbird.Client, peerIP string) error {
log.Infof("Waiting for peer %s to be connected before operation", peerIP)

ticker := time.NewTicker(peerConnectionPollInterval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return fmt.Errorf("timeout waiting for peer %s to connect", peerIP)
case <-ticker.C:
overview, err := getStatusOverview(client)
if err != nil {
log.Debugf("Error getting status while waiting for peer: %v", err)
continue
}

for _, peer := range overview.Peers.Details {
if peer.IP == peerIP && peer.Status == "Connected" {
log.Infof("Peer %s is now connected (type: %s), proceeding with operation", peerIP, peer.ConnType)
return nil
}
}
log.Tracef("Peer %s not yet connected, waiting...", peerIP)
}
}
}

// createDetectSSHServerMethod creates the SSH server detection method
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
@@ -220,6 +602,10 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()

// Note: Don't wait for peer connection here - lazy connection manager
// will open the connection when Dial is called. Waiting would cause
// a deadlock since lazy connections only open on traffic.

serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
if err != nil {
reject.Invoke(err.Error())
@@ -237,17 +623,25 @@ func createClientObject(client *netbird.Client) js.Value {

obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client)
obj["ping"] = createPingMethod(client)
obj["pingtcp"] = createPingTCPMethod(client)
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
obj["waitForPeerConnection"] = createWaitForPeerConnectionMethod(client)
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
obj["setLogLevel"] = createSetLogLevelMethod(client)

return js.ValueOf(obj)
}

// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(this js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any {
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
resolve := promiseArgs[0]
reject := promiseArgs[1]

@@ -263,7 +657,10 @@ func netBirdClientConstructor(this js.Value, args []js.Value) any {
return
}

if err := util.InitLog(options.LogLevel, util.LogConsole); err != nil {
// Force trace logging for debugging - must be set BEFORE netbird.New
// as New() will call logrus.SetLevel with options.LogLevel
options.LogLevel = "trace"
if err := util.InitLog("trace", util.LogConsole); err != nil {
log.Warnf("Failed to initialize logging: %v", err)
}

8 go.mod
@@ -78,8 +78,8 @@ require (
github.com/pion/logging v0.2.4
github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0
github.com/pion/stun/v3 v3.0.0
github.com/pion/transport/v3 v3.0.7
github.com/pion/stun/v3 v3.1.0
github.com/pion/transport/v3 v3.1.1
github.com/pion/turn/v3 v3.0.1
github.com/pkg/sftp v1.13.9
github.com/prometheus/client_golang v1.23.2
@@ -241,7 +241,7 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/dtls/v3 v3.0.7 // indirect
github.com/pion/dtls/v3 v3.0.9 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pion/turn/v4 v4.1.1 // indirect
@@ -263,7 +263,7 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wlynxg/anet v0.0.3 // indirect
github.com/wlynxg/anet v0.0.5 // indirect
github.com/yuin/goldmark v1.7.8 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect

16 go.sum
@@ -444,8 +444,8 @@ github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7c
github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
github.com/pion/dtls/v3 v3.0.7 h1:bItXtTYYhZwkPFk4t1n3Kkf5TDrfj6+4wG+CZR8uI9Q=
github.com/pion/dtls/v3 v3.0.7/go.mod h1:uDlH5VPrgOQIw59irKYkMudSFprY9IEFCqz/eTz16f8=
github.com/pion/dtls/v3 v3.0.9 h1:4AijfFRm8mAjd1gfdlB1wzJF3fjjR/VPIpJgkEtvYmM=
github.com/pion/dtls/v3 v3.0.9/go.mod h1:abApPjgadS/ra1wvUzHLc3o2HvoxppAh+NZkyApL4Os=
github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8=
github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so=
@@ -455,14 +455,14 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA=
github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8=
github.com/pion/stun/v2 v2.0.0 h1:A5+wXKLAypxQri59+tmQKVs7+l6mMM+3d+eER9ifRU0=
github.com/pion/stun/v2 v2.0.0/go.mod h1:22qRSh08fSEttYUmJZGlriq9+03jtVmXNODgLccj8GQ=
github.com/pion/stun/v3 v3.0.0 h1:4h1gwhWLWuZWOJIJR9s2ferRO+W3zA/b6ijOI6mKzUw=
github.com/pion/stun/v3 v3.0.0/go.mod h1:HvCN8txt8mwi4FBvS3EmDghW6aQJ24T+y+1TKjB5jyU=
github.com/pion/stun/v3 v3.1.0 h1:bS1jjT3tGWZ4UPmIUeyalOylamTMTFg1OvXtY/r6seM=
github.com/pion/stun/v3 v3.1.0/go.mod h1:egmx1CUcfSSGJxQCOjtVlomfPqmQ58BibPyuOWNGQEU=
github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
github.com/pion/transport/v3 v3.0.1/go.mod h1:UY7kiITrlMv7/IKgd5eTUcaahZx5oUN3l9SzK5f5xE0=
github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0=
github.com/pion/transport/v3 v3.0.7/go.mod h1:YleKiTZ4vqNxVwh77Z0zytYi7rXHl7j6uPLGhhz9rwo=
github.com/pion/transport/v3 v3.1.1 h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM=
github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ=
github.com/pion/turn/v3 v3.0.1 h1:wLi7BTQr6/Q20R0vt/lHbjv6y4GChFtC33nkYbasoT8=
github.com/pion/turn/v3 v3.0.1/go.mod h1:MrJDKgqryDyWy1/4NT9TWfXWGMC7UHT6pJIv1+gMeNE=
github.com/pion/turn/v4 v4.1.1 h1:9UnY2HB99tpDyz3cVVZguSxcqkJ1DsTSZ+8TGruh4fc=
@@ -574,8 +574,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/wlynxg/anet v0.0.3 h1:PvR53psxFXstc12jelG6f1Lv4MWqE0tI76/hHGjh9rg=
github.com/wlynxg/anet v0.0.3/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU=
github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=

25 management/server/cache/idp.go vendored
@@ -26,6 +26,8 @@ type UserDataCache interface {
Get(ctx context.Context, key string) (*idp.UserData, error)
Set(ctx context.Context, key string, value *idp.UserData, expiration time.Duration) error
Delete(ctx context.Context, key string) error
GetUsers(ctx context.Context, key string) ([]*idp.UserData, error)
SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error
}

// UserDataCacheImpl is a struct that implements the UserDataCache interface.
@@ -51,6 +53,29 @@ func (u *UserDataCacheImpl) Delete(ctx context.Context, key string) error {
return u.cache.Delete(ctx, key)
}

func (u *UserDataCacheImpl) GetUsers(ctx context.Context, key string) ([]*idp.UserData, error) {
var users []*idp.UserData
v, err := u.cache.Get(ctx, key, &users)
if err != nil {
return nil, err
}

switch v := v.(type) {
case []*idp.UserData:
return v, nil
case *[]*idp.UserData:
return *v, nil
case []byte:
return unmarshalUserData(v)
}

return nil, fmt.Errorf("unexpected type: %T", v)
}

func (u *UserDataCacheImpl) SetUsers(ctx context.Context, key string, users []*idp.UserData, expiration time.Duration) error {
return u.cache.Set(ctx, key, users, store.WithExpiration(expiration))
}

// NewUserDataCache creates a new UserDataCacheImpl object.
func NewUserDataCache(store store.StoreInterface) *UserDataCacheImpl {
simpleCache := cache.New[any](store)

@@ -178,7 +178,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
m.patUsageTracker.IncrementUsage(token)
}

if m.rateLimiter != nil {
if m.rateLimiter != nil && !isTerraformRequest(r) {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
@@ -214,6 +214,11 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
return nbcontext.SetUserAuthInRequest(r, userAuth), nil
}

func isTerraformRequest(r *http.Request) bool {
ua := strings.ToLower(r.Header.Get("User-Agent"))
return strings.Contains(ua, "terraform")
}

// getTokenFromJWTRequest is a "TokenExtractor" that takes auth header parts and extracts
// the JWT token from the Authorization header.
func getTokenFromJWTRequest(authHeaderParts []string) (string, error) {

@@ -508,6 +508,103 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
})

t.Run("Terraform User Agent Not Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}

authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)

handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

// Test various Terraform user agent formats
terraformUserAgents := []string{
"Terraform/1.5.0",
"terraform/1.0.0",
"Terraform-Provider/2.0.0",
"Mozilla/5.0 (compatible; Terraform/1.3.0)",
}

for _, userAgent := range terraformUserAgents {
t.Run("UserAgent: "+userAgent, func(t *testing.T) {
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", userAgent)
rec := httptest.NewRecorder()

handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}

assert.Equal(t, 10, successCount, "All Terraform user agent requests should succeed (not rate limited)")
})
}
})

t.Run("Non-Terraform User Agent With PAT Is Rate Limited", func(t *testing.T) {
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}

authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
nil,
)

handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")

req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
req.Header.Set("User-Agent", "curl/7.68.0")
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
}

func TestAuthMiddleware_Handler_Child(t *testing.T) {

@@ -911,10 +911,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{}
switch {
case allowed:
start := time.Now()
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
log.WithContext(ctx).Tracef("Got %d users from account %s after %s", len(accountUsers), accountID, time.Since(start))
case user != nil && user.AccountID == accountID:
accountUsers = append(accountUsers, user)
default:
@@ -933,23 +935,40 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) {
users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0)
filtered := make(map[string]*idp.UserData, len(accountUsers))
log.WithContext(ctx).Tracef("Querying users from IDP for account %s", accountID)
start := time.Now()

integrationKeys := make(map[string]struct{})
for _, user := range accountUsers {
if user.Issued == types.UserIssuedIntegration {
key := user.IntegrationReference.CacheKey(accountID, user.Id)
info, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.WithContext(ctx).Infof("Get ExternalCache for key: %s, error: %s", key, err)
users[user.Id] = true
continue
}
usersFromIntegration = append(usersFromIntegration, info)
integrationKeys[user.IntegrationReference.CacheKey(accountID)] = struct{}{}
continue
}
if !user.IsServiceUser {
users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero())
}
}

for key := range integrationKeys {
usersData, err := am.externalCacheManager.GetUsers(am.ctx, key)
if err != nil {
log.WithContext(ctx).Debugf("GetUsers from ExternalCache for key: %s, error: %s", key, err)
continue
}
for _, ud := range usersData {
filtered[ud.ID] = ud
}
}

for _, ud := range filtered {
usersFromIntegration = append(usersFromIntegration, ud)
}

log.WithContext(ctx).Tracef("Got user info from external cache after %s", time.Since(start))
start = time.Now()
queriedUsers, err = am.lookupCache(ctx, users, accountID)
log.WithContext(ctx).Tracef("Got user info from cache for %d users after %s", len(queriedUsers), time.Since(start))
if err != nil {
return nil, err
}

@@ -1086,8 +1086,12 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) {
assert.NoError(t, err)

cacheManager := am.GetExternalCacheManager()
cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}, time.Minute)
tud := &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}
cacheKeyUser := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id)
err = cacheManager.Set(context.Background(), cacheKeyUser, tud, time.Minute)
assert.NoError(t, err)
cacheKeyAccount := externalUser.IntegrationReference.CacheKey(mockAccountID)
err = cacheManager.SetUsers(context.Background(), cacheKeyAccount, []*idp.UserData{tud}, time.Minute)
assert.NoError(t, err)

infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID)

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/relay/server"
"github.com/netbirdio/netbird/shared/relay/auth"
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/stun"
"github.com/netbirdio/netbird/util"
)

@@ -43,6 +45,10 @@ type Config struct {
LogLevel string
LogFile string
HealthcheckListenAddress string
// STUN server configuration
EnableSTUN bool
STUNPorts []int
STUNLogLevel string
}

func (c Config) Validate() error {
@@ -52,6 +58,25 @@ func (c Config) Validate() error {
if c.AuthSecret == "" {
return fmt.Errorf("auth secret is required")
}

// Validate STUN configuration
if c.EnableSTUN {
if len(c.STUNPorts) == 0 {
return fmt.Errorf("--stun-ports is required when --enable-stun is set")
}

seen := make(map[int]bool)
for _, port := range c.STUNPorts {
if port <= 0 || port > 65535 {
return fmt.Errorf("invalid STUN port %d: must be between 1 and 65535", port)
}
if seen[port] {
return fmt.Errorf("duplicate STUN port %d", port)
}
seen[port] = true
}
}

return nil
}

@@ -91,6 +116,9 @@ func init() {
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
rootCmd.PersistentFlags().StringVarP(&cobraConfig.HealthcheckListenAddress, "health-listen-address", "H", ":9000", "listen address of healthcheck server")
rootCmd.PersistentFlags().BoolVar(&cobraConfig.EnableSTUN, "enable-stun", false, "enable embedded STUN server")
rootCmd.PersistentFlags().IntSliceVar(&cobraConfig.STUNPorts, "stun-ports", []int{3478}, "ports for the embedded STUN server (can be specified multiple times or comma-separated)")
rootCmd.PersistentFlags().StringVar(&cobraConfig.STUNLogLevel, "stun-log-level", "info", "log level for STUN server (panic, fatal, error, warn, info, debug, trace)")

setFlagsFromEnvVars(rootCmd)
}
@@ -119,21 +147,14 @@ func execute(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to initialize log: %s", err)
}

// Resource creation phase (fail fast before starting any goroutines)

metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
if err != nil {
log.Debugf("setup metrics: %v", err)
return fmt.Errorf("setup metrics: %v", err)
}

wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start metrics server: %v", err)
}
}()

srvListenerCfg := server.ListenerConfig{
Address: cobraConfig.ListenAddress,
}
@@ -145,6 +166,12 @@ func execute(cmd *cobra.Command, args []string) error {
}
srvListenerCfg.TLSConfig = tlsConfig

// Create STUN listeners early to fail fast
stunListeners, err := createSTUNListeners()
if err != nil {
return err
}

hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)

@@ -155,60 +182,145 @@ func execute(cmd *cobra.Command, args []string) error {
TLSSupport: tlsSupport,
}

srv, err := server.NewServer(cfg)
srv, err := createRelayServer(cfg)
if err != nil {
log.Debugf("failed to create relay server: %v", err)
return fmt.Errorf("failed to create relay server: %v", err)
cleanupSTUNListeners(stunListeners)
return err
}

hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := createHealthCheck(hCfg)
if err != nil {
cleanupSTUNListeners(stunListeners)
return err
}

var stunServer *stun.Server
if len(stunListeners) > 0 {
stunServer = stun.NewServer(stunListeners, cobraConfig.STUNLogLevel)
}

// Start all servers (only after all resources are successfully created)
startServers(&wg, metricsServer, srv, srvListenerCfg, httpHealthcheck, stunServer)

waitForExitSignal()

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

err = shutdownServers(ctx, metricsServer, srv, httpHealthcheck, stunServer)
wg.Wait()
return err
}

func startServers(wg *sync.WaitGroup, metricsServer *metrics.Metrics, srv *server.Server, srvListenerCfg server.ListenerConfig, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) {
wg.Add(1)
go func() {
defer wg.Done()
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()

instanceURL := srv.InstanceURL()
log.Infof("server will be available on: %s", instanceURL.String())
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Listen(srvListenerCfg); err != nil {
log.Fatalf("failed to bind server: %s", err)
log.Fatalf("failed to bind relay server: %s", err)
}
}()

hCfg := healthcheck.Config{
ListenAddress: cobraConfig.HealthcheckListenAddress,
ServiceChecker: srv,
}
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return fmt.Errorf("failed to create healthcheck server: %v", err)
}
wg.Add(1)
go func() {
defer wg.Done()
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("Failed to start healthcheck server: %v", err)
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()

// it will block until exit signal
waitForExitSignal()
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
}
}

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
func shutdownServers(ctx context.Context, metricsServer *metrics.Metrics, srv *server.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server) error {
var errs error

var shutDownErrors error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close healthcheck server: %v", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close healthcheck server: %w", err))
}

if stunServer != nil {
if err := stunServer.Shutdown(); err != nil {
errs = multierror.Append(errs, fmt.Errorf("failed to close STUN server: %w", err))
}
}

if err := srv.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close relay server: %w", err))
}

log.Infof("shutting down metrics server")
if err := metricsServer.Shutdown(ctx); err != nil {
shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
errs = multierror.Append(errs, fmt.Errorf("failed to close metrics server: %w", err))
}

wg.Wait()
return shutDownErrors
return errs
}

func createHealthCheck(hCfg healthcheck.Config) (*healthcheck.Server, error) {
httpHealthcheck, err := healthcheck.NewServer(hCfg)
if err != nil {
log.Debugf("failed to create healthcheck server: %v", err)
return nil, fmt.Errorf("failed to create healthcheck server: %v", err)
}
return httpHealthcheck, nil
}

func createRelayServer(cfg server.Config) (*server.Server, error) {
srv, err := server.NewServer(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create relay server: %v", err)
}
return srv, nil
}

func cleanupSTUNListeners(stunListeners []*net.UDPConn) {
for _, l := range stunListeners {
_ = l.Close()
}
}

func createSTUNListeners() ([]*net.UDPConn, error) {
var stunListeners []*net.UDPConn
if cobraConfig.EnableSTUN {
for _, port := range cobraConfig.STUNPorts {
listener, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
if err != nil {
// Close already opened listeners on failure
cleanupSTUNListeners(stunListeners)
log.Debugf("failed to create STUN listener on port %d: %v", port, err)
return nil, fmt.Errorf("failed to create STUN listener on port %d: %v", port, err)
}
stunListeners = append(stunListeners, listener)
}
}
return stunListeners, nil
}

func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {

@@ -8,6 +8,10 @@ pid="$(pgrep -x -f /usr/bin/netbird-ui || true)"
if [ -n "${pid}" ]
then
uid="$(cat /proc/"${pid}"/loginuid)"
# loginuid can be 4294967295 (-1) if not set, fall back to process uid
if [ "${uid}" = "4294967295" ] || [ "${uid}" = "-1" ]; then
uid="$(stat -c '%u' /proc/"${pid}")"
fi
username="$(id -nu "${uid}")"
# Only re-run if it was already running
pkill -x -f /usr/bin/netbird-ui >/dev/null 2>&1

170 stun/server.go Normal file
@@ -0,0 +1,170 @@
// Package stun provides an embedded STUN server for NAT traversal discovery.
package stun

import (
"errors"
"fmt"
"net"
"sync"

"github.com/hashicorp/go-multierror"
nberrors "github.com/netbirdio/netbird/client/errors"
log "github.com/sirupsen/logrus"

"github.com/netbirdio/netbird/formatter"
"github.com/pion/stun/v3"
)

// ErrServerClosed is returned by Listen when the server is shut down gracefully.
var ErrServerClosed = errors.New("stun: server closed")

// ErrNoListeners is returned by Listen when no UDP connections were provided.
var ErrNoListeners = errors.New("stun: no listeners configured")

// Server implements a STUN server that responds to binding requests
// with the client's reflexive transport address.
type Server struct {
conns []*net.UDPConn
logger *log.Entry
logLevel log.Level

wg sync.WaitGroup
}

// NewServer creates a new STUN server with the given UDP listeners.
// The caller is responsible for creating and providing the listeners.
// logLevel can be: panic, fatal, error, warn, info, debug, trace
func NewServer(conns []*net.UDPConn, logLevel string) *Server {
level, err := log.ParseLevel(logLevel)
if err != nil {
level = log.InfoLevel
}

// Create a separate logger with its own level setting
// This allows --stun-log-level to work independently of --log-level
stunLogger := log.New()
stunLogger.SetOutput(log.StandardLogger().Out)
stunLogger.SetLevel(level)
// Use the formatter package to set up formatter, ReportCaller, and context hook
formatter.SetTextFormatter(stunLogger)

logger := stunLogger.WithField("component", "stun-server")
logger.Infof("STUN server log level set to: %s", level.String())

return &Server{
conns: conns,
logger: logger,
logLevel: level,
}
}

// Listen starts the STUN server and blocks until the server is shut down.
// Returns ErrServerClosed when shut down gracefully via Shutdown.
// Returns ErrNoListeners if no UDP connections were provided.
func (s *Server) Listen() error {
if len(s.conns) == 0 {
return ErrNoListeners
}

// Start a read loop for each listener
for _, conn := range s.conns {
s.logger.Infof("STUN server listening on %s", conn.LocalAddr())
s.wg.Add(1)
go s.readLoop(conn)
}

s.wg.Wait()
return ErrServerClosed
}

// readLoop continuously reads UDP packets and handles STUN requests.
func (s *Server) readLoop(conn *net.UDPConn) {
defer s.wg.Done()
buf := make([]byte, 1500) // Standard MTU size
for {
n, remoteAddr, err := conn.ReadFromUDP(buf)

if err != nil {
// Check if the connection was closed externally
if errors.Is(err, net.ErrClosed) {
s.logger.Info("UDP connection closed, stopping read loop")
return
}
s.logger.Warnf("failed to read UDP packet: %v", err)
continue
}

// Handle packet in the same goroutine to avoid complexity
// STUN responses are small and fast
s.handlePacket(conn, buf[:n], remoteAddr)
}
}

// handlePacket processes a STUN request and sends a response.
func (s *Server) handlePacket(conn *net.UDPConn, data []byte, addr *net.UDPAddr) {
localPort := conn.LocalAddr().(*net.UDPAddr).Port

s.logger.Debugf("[port:%d] received %d bytes from %s", localPort, len(data), addr)

// Check if it's a STUN message
if !stun.IsMessage(data) {
s.logger.Debugf("[port:%d] not a STUN message (first bytes: %x)", localPort, data[:min(len(data), 8)])
return
}

// Parse the STUN message
msg := &stun.Message{Raw: data}
if err := msg.Decode(); err != nil {
s.logger.Warnf("[port:%d] failed to decode STUN message from %s: %v", localPort, addr, err)
return
}

s.logger.Debugf("[port:%d] received STUN %s from %s (tx=%x)", localPort, msg.Type, addr, msg.TransactionID[:8])

// Only handle binding requests
if msg.Type != stun.BindingRequest {
s.logger.Debugf("[port:%d] ignoring non-binding request: %s", localPort, msg.Type)
return
}

// Build the response
response, err := stun.Build(
stun.NewTransactionIDSetter(msg.TransactionID),
stun.BindingSuccess,
&stun.XORMappedAddress{
IP: addr.IP,
Port: addr.Port,
},
stun.Fingerprint,
)
if err != nil {
s.logger.Errorf("[port:%d] failed to build STUN response: %v", localPort, err)
return
}

// Send the response on the same connection it was received on
n, err := conn.WriteToUDP(response.Raw, addr)
if err != nil {
s.logger.Errorf("[port:%d] failed to send STUN response to %s: %v", localPort, addr, err)
return
}

s.logger.Debugf("[port:%d] sent STUN BindingSuccess to %s (%d bytes) with XORMappedAddress %s:%d", localPort, addr, n, addr.IP, addr.Port)
}

// Shutdown gracefully stops the STUN server.
func (s *Server) Shutdown() error {
s.logger.Info("shutting down STUN server")

var merr *multierror.Error

for _, conn := range s.conns {
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
merr = multierror.Append(merr, fmt.Errorf("close STUN UDP connection: %w", err))
}
}

// Wait for all readLoops to finish
s.wg.Wait()
return nberrors.FormatErrorOrNil(merr)
}

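For orientation, here is a minimal standalone sketch of how the new stun.Server could be wired up outside the relay command. The import path and the NewServer/Listen/Shutdown/ErrServerClosed API are taken from the file above; the port choice, signal handling, and error handling are illustrative assumptions only, not part of this change.

package main

import (
	"errors"
	"log"
	"net"
	"os"
	"os/signal"

	"github.com/netbirdio/netbird/stun"
)

func main() {
	// One UDP listener on the default STUN port; the caller owns the sockets,
	// matching how createSTUNListeners hands connections to stun.NewServer.
	conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: 3478})
	if err != nil {
		log.Fatalf("listen udp: %v", err)
	}

	srv := stun.NewServer([]*net.UDPConn{conn}, "debug")

	done := make(chan error, 1)
	go func() {
		// Listen blocks until all read loops exit and returns ErrServerClosed
		// after a graceful Shutdown.
		done <- srv.Listen()
	}()

	// Wait for Ctrl+C, then close the sockets and stop the read loops.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt)
	<-sig

	if err := srv.Shutdown(); err != nil {
		log.Printf("shutdown: %v", err)
	}
	if err := <-done; err != nil && !errors.Is(err, stun.ErrServerClosed) {
		log.Printf("listen: %v", err)
	}
}
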
479
stun/server_test.go
Normal file
479
stun/server_test.go
Normal file
@@ -0,0 +1,479 @@
|
||||
package stun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pion/stun/v3"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// createTestServer creates a STUN server listening on a random port for testing.
|
||||
// Returns the server, the listener connection (caller must close), and the server address.
|
||||
func createTestServer(t testing.TB) (*Server, *net.UDPConn, *net.UDPAddr) {
|
||||
t.Helper()
|
||||
conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
|
||||
require.NoError(t, err)
|
||||
server := NewServer([]*net.UDPConn{conn}, "debug")
|
||||
return server, conn, conn.LocalAddr().(*net.UDPAddr)
|
||||
}
|
||||
|
||||
// waitForServerReady polls the server with STUN binding requests until it responds.
|
||||
// This avoids flaky tests on slow CI machines that relied on time.Sleep.
|
||||
func waitForServerReady(t testing.TB, serverAddr *net.UDPAddr, timeout time.Duration) {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
retryInterval := 10 * time.Millisecond
|
||||
|
||||
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
buf := make([]byte, 1500)
|
||||
for time.Now().Before(deadline) {
|
||||
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = clientConn.Write(msg.Raw)
|
||||
require.NoError(t, err)
|
||||
|
||||
_ = clientConn.SetReadDeadline(time.Now().Add(retryInterval))
|
||||
n, err := clientConn.Read(buf)
|
||||
if err != nil {
|
||||
// Timeout or other error, retry
|
||||
continue
|
||||
}
|
||||
|
||||
response := &stun.Message{Raw: buf[:n]}
|
||||
if err := response.Decode(); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if response.Type == stun.BindingSuccess {
|
||||
return // Server is ready
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatalf("server did not become ready within %v", timeout)
|
||||
}
|
||||
|
||||
func TestServer_BindingRequest(t *testing.T) {
|
||||
// Start the STUN server on a random port
|
||||
server, listener, serverAddr := createTestServer(t)
|
||||
|
||||
// Start server in background
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
serverErrCh <- server.Listen()
|
||||
}()
|
||||
|
||||
// Wait for server to be ready
|
||||
waitForServerReady(t, serverAddr, 2*time.Second)
|
||||
|
||||
// Create a UDP client
|
||||
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
// Build a STUN binding request
|
||||
msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Send the request
|
||||
_, err = clientConn.Write(msg.Raw)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read the response
|
||||
buf := make([]byte, 1500)
|
||||
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
n, err := clientConn.Read(buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse the response
|
||||
response := &stun.Message{Raw: buf[:n]}
|
||||
err = response.Decode()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify it's a binding success
|
||||
assert.Equal(t, stun.BindingSuccess, response.Type)
|
||||
|
||||
// Extract the XOR-MAPPED-ADDRESS
|
||||
var xorAddr stun.XORMappedAddress
|
||||
err = xorAddr.GetFrom(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the address matches our client's local address
|
||||
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
|
||||
assert.Equal(t, clientAddr.IP.String(), xorAddr.IP.String())
|
||||
assert.Equal(t, clientAddr.Port, xorAddr.Port)
|
||||
|
||||
// Close listener first to unblock readLoop, then shutdown
|
||||
_ = listener.Close()
|
||||
err = server.Shutdown()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestServer_IgnoresNonSTUNPackets(t *testing.T) {
|
||||
server, listener, serverAddr := createTestServer(t)
|
||||
|
||||
go func() {
|
||||
_ = server.Listen()
|
||||
}()
|
||||
|
||||
waitForServerReady(t, serverAddr, 2*time.Second)
|
||||
|
||||
clientConn, err := net.DialUDP("udp", nil, serverAddr)
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
// Send non-STUN data
|
||||
_, err = clientConn.Write([]byte("hello world"))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to read response (should timeout since server ignores non-STUN)
|
||||
buf := make([]byte, 1500)
|
||||
_ = clientConn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
_, err = clientConn.Read(buf)
|
||||
assert.Error(t, err) // Should be a timeout error
|
||||
|
||||
// Close listener first to unblock readLoop, then shutdown
|
||||
_ = listener.Close()
|
||||
_ = server.Shutdown()
|
||||
}
|
||||
|
||||
func TestServer_Shutdown(t *testing.T) {
|
||||
server, listener, serverAddr := createTestServer(t)
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
go func() {
|
||||
err := server.Listen()
|
||||
assert.True(t, errors.Is(err, ErrServerClosed))
|
||||
close(serverDone)
|
||||
}()
|
||||
|
||||
waitForServerReady(t, serverAddr, 2*time.Second)
|
||||
|
||||
// Close listener first to unblock readLoop, then shutdown
|
||||
_ = listener.Close()
|
||||
|
||||
err := server.Shutdown()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for Listen to return
|
||||
select {
|
||||
case <-serverDone:
|
||||
// Success
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("server did not shutdown in time")
|
||||
}
|
||||
}
|
||||
|
||||

func TestServer_MultipleRequests(t *testing.T) {
	server, listener, serverAddr := createTestServer(t)

	go func() {
		_ = server.Listen()
	}()

	waitForServerReady(t, serverAddr, 2*time.Second)

	// Create multiple clients and send requests
	for i := 0; i < 5; i++ {
		func() {
			clientConn, err := net.DialUDP("udp", nil, serverAddr)
			require.NoError(t, err)
			defer clientConn.Close()

			msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
			require.NoError(t, err)

			_, err = clientConn.Write(msg.Raw)
			require.NoError(t, err)

			buf := make([]byte, 1500)
			_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
			n, err := clientConn.Read(buf)
			require.NoError(t, err)

			response := &stun.Message{Raw: buf[:n]}
			err = response.Decode()
			require.NoError(t, err)

			assert.Equal(t, stun.BindingSuccess, response.Type)
		}()
	}

	// Close listener first to unblock readLoop, then shutdown
	_ = listener.Close()
	_ = server.Shutdown()
}

func TestServer_ConcurrentClients(t *testing.T) {
	numClients := 100
	requestsPerClient := 5
	maxStartDelay := 100 * time.Millisecond   // Random delay before client starts
	maxRequestDelay := 500 * time.Millisecond // Random delay between requests

	// Optional remote server to test against, set via the STUN_TEST_SERVER env var.
	// Example: STUN_TEST_SERVER=example.netbird.io:3478 go test -v ./stun/... -run ConcurrentClients
	remoteServer := os.Getenv("STUN_TEST_SERVER")

	var serverAddr *net.UDPAddr
	var server *Server
	var listener *net.UDPConn

	if remoteServer != "" {
		// Use remote server
		var err error
		serverAddr, err = net.ResolveUDPAddr("udp", remoteServer)
		require.NoError(t, err)
		t.Logf("Testing against remote server: %s", remoteServer)
	} else {
		// Start local server
		server, listener, serverAddr = createTestServer(t)
		go func() {
			_ = server.Listen()
		}()
		waitForServerReady(t, serverAddr, 2*time.Second)
		t.Logf("Testing against local server: %s", serverAddr)
	}

	var wg sync.WaitGroup
	errorz := make(chan error, numClients*requestsPerClient)
	successCount := make(chan int, numClients)

	startTime := time.Now()

	for i := 0; i < numClients; i++ {
		wg.Add(1)
		go func(clientID int) {
			defer wg.Done()

			// Random delay before starting
			time.Sleep(time.Duration(rand.Int63n(int64(maxStartDelay))))

			clientConn, err := net.DialUDP("udp", nil, serverAddr)
			if err != nil {
				errorz <- fmt.Errorf("client %d: failed to dial: %w", clientID, err)
				return
			}
			defer clientConn.Close()

			success := 0
			for j := 0; j < requestsPerClient; j++ {
				// Random delay between requests
				if j > 0 {
					time.Sleep(time.Duration(rand.Int63n(int64(maxRequestDelay))))
				}

				msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
				if err != nil {
					errorz <- fmt.Errorf("client %d: failed to build request: %w", clientID, err)
					continue
				}

				_, err = clientConn.Write(msg.Raw)
				if err != nil {
					errorz <- fmt.Errorf("client %d: failed to write: %w", clientID, err)
					continue
				}

				buf := make([]byte, 1500)
				_ = clientConn.SetReadDeadline(time.Now().Add(5 * time.Second))
				n, err := clientConn.Read(buf)
				if err != nil {
					errorz <- fmt.Errorf("client %d: failed to read: %w", clientID, err)
					continue
				}

				response := &stun.Message{Raw: buf[:n]}
				if err := response.Decode(); err != nil {
					errorz <- fmt.Errorf("client %d: failed to decode: %w", clientID, err)
					continue
				}

				if response.Type != stun.BindingSuccess {
					errorz <- fmt.Errorf("client %d: unexpected response type: %s", clientID, response.Type)
					continue
				}

				success++
			}
			successCount <- success
		}(i)
	}

	wg.Wait()
	close(errorz)
	close(successCount)

	elapsed := time.Since(startTime)

	totalSuccess := 0
	for count := range successCount {
		totalSuccess += count
	}

	var errs []error
	for err := range errorz {
		errs = append(errs, err)
	}

	totalRequests := numClients * requestsPerClient
	t.Logf("Completed %d/%d requests in %v (%.2f req/s)",
		totalSuccess, totalRequests, elapsed,
		float64(totalSuccess)/elapsed.Seconds())

	if len(errs) > 0 {
		t.Logf("Errors (%d):", len(errs))
		for i, err := range errs {
			if i < 10 { // Only show first 10 errors
				t.Logf(" - %v", err)
			}
		}
	}

	// Require at least 95% success rate
	successRate := float64(totalSuccess) / float64(totalRequests)
	require.GreaterOrEqual(t, successRate, 0.95, "success rate too low: %.2f%%", successRate*100)

	// Clean up the local server if one was started
	if server != nil {
		// Close listener first to unblock readLoop, then shutdown
		_ = listener.Close()
		_ = server.Shutdown()
	}
}

func TestServer_MultiplePorts(t *testing.T) {
	// Create listeners on two random ports
	conn1, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
	require.NoError(t, err)
	conn2, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
	require.NoError(t, err)

	addr1 := conn1.LocalAddr().(*net.UDPAddr)
	addr2 := conn2.LocalAddr().(*net.UDPAddr)

	server := NewServer([]*net.UDPConn{conn1, conn2}, "debug")

	go func() {
		_ = server.Listen()
	}()

	// Wait for server to be ready (checking first port is sufficient)
	waitForServerReady(t, addr1, 2*time.Second)

	// Test requests on both ports
	for _, serverAddr := range []*net.UDPAddr{addr1, addr2} {
		func() {
			clientConn, err := net.DialUDP("udp", nil, serverAddr)
			require.NoError(t, err)
			defer clientConn.Close()

			msg, err := stun.Build(stun.TransactionID, stun.BindingRequest)
			require.NoError(t, err)

			_, err = clientConn.Write(msg.Raw)
			require.NoError(t, err)

			buf := make([]byte, 1500)
			_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
			n, err := clientConn.Read(buf)
			require.NoError(t, err)

			response := &stun.Message{Raw: buf[:n]}
			err = response.Decode()
			require.NoError(t, err)

			assert.Equal(t, stun.BindingSuccess, response.Type)

			var xorAddr stun.XORMappedAddress
			err = xorAddr.GetFrom(response)
			require.NoError(t, err)

			clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
			assert.Equal(t, clientAddr.Port, xorAddr.Port)
		}()
	}

	// Close listeners first to unblock readLoops, then shutdown
	_ = conn1.Close()
	_ = conn2.Close()
	_ = server.Shutdown()
}
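
// The multi-port test above also doubles as a usage example for the server:
// callers hand NewServer a set of pre-bound UDP listeners (the second argument
// is assumed to be a log level, matching the "debug" value used in the test),
// run Listen in the background, and later close the listeners before calling
// Shutdown. A minimal, illustrative sketch of that flow outside of tests:
//
//	conns := make([]*net.UDPConn, 0, 2)
//	for _, port := range []int{3478, 3479} {
//		conn, err := net.ListenUDP("udp", &net.UDPAddr{Port: port})
//		if err != nil {
//			log.Fatalf("listen on port %d: %v", port, err)
//		}
//		conns = append(conns, conn)
//	}
//	srv := NewServer(conns, "info")
//	go func() {
//		if err := srv.Listen(); err != nil && !errors.Is(err, ErrServerClosed) {
//			log.Printf("stun server stopped: %v", err)
//		}
//	}()
//	// ... on shutdown, close listeners first to unblock the read loops:
//	for _, conn := range conns {
//		_ = conn.Close()
//	}
//	_ = srv.Shutdown()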

// BenchmarkSTUNServer benchmarks the STUN server with concurrent clients
func BenchmarkSTUNServer(b *testing.B) {
	server, listener, serverAddr := createTestServer(b)

	go func() {
		_ = server.Listen()
	}()

	waitForServerReady(b, serverAddr, 2*time.Second)

	// Capture first error atomically - b.Fatal cannot be called from worker goroutines
	var firstErr atomic.Pointer[error]
	setErr := func(err error) {
		firstErr.CompareAndSwap(nil, &err)
	}

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		// Stop work if an error has occurred
		if firstErr.Load() != nil {
			return
		}

		clientConn, err := net.DialUDP("udp", nil, serverAddr)
		if err != nil {
			setErr(err)
			return
		}
		defer clientConn.Close()

		buf := make([]byte, 1500)

		for pb.Next() {
			if firstErr.Load() != nil {
				return
			}

			msg, _ := stun.Build(stun.TransactionID, stun.BindingRequest)
			_, _ = clientConn.Write(msg.Raw)

			_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
			n, err := clientConn.Read(buf)
			if err != nil {
				setErr(err)
				return
			}

			response := &stun.Message{Raw: buf[:n]}
			if err := response.Decode(); err != nil {
				setErr(err)
				return
			}
		}
	})

	b.StopTimer()

	// Fail after RunParallel completes
	if errPtr := firstErr.Load(); errPtr != nil {
		b.Fatal(*errPtr)
	}

	// Close listener first to unblock readLoop, then shutdown
	_ = listener.Close()
	_ = server.Shutdown()
}