diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 7ca56857b..e56f66103 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -314,9 +314,8 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { profName = activeProf.Name } - statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), - ) + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName) + statusOutputString = overview.FullDetailSummary() } return statusOutputString } diff --git a/client/cmd/status.go b/client/cmd/status.go index 99d47cd1a..05175663c 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -103,13 +103,13 @@ func statusFunc(cmd *cobra.Command, args []string) error { var statusOutputString string switch { case detailFlag: - statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder) + statusOutputString = outputInformationHolder.FullDetailSummary() case jsonFlag: - statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder) + statusOutputString, err = outputInformationHolder.JSON() case yamlFlag: - statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder) + statusOutputString, err = outputInformationHolder.YAML() default: - statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false) + statusOutputString = outputInformationHolder.GeneralSummary(false, false, false, false) } if err != nil { diff --git a/client/embed/embed.go b/client/embed/embed.go index 353c5438f..43089fc9d 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" sshcommon "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" + mgmProto "github.com/netbirdio/netbird/shared/management/proto" ) var ( @@ -38,6 +39,7 @@ type Client struct { setupKey string jwtToken string connect *internal.ConnectClient + recorder *peer.Status } // Options configures a new Client. @@ -161,11 +163,17 @@ func New(opts Options) (*Client, error) { func (c *Client) Start(startCtx context.Context) error { c.mu.Lock() defer c.mu.Unlock() - if c.cancel != nil { + if c.connect != nil { return ErrClientAlreadyStarted } - ctx := internal.CtxInitState(context.Background()) + ctx, cancel := context.WithCancel(internal.CtxInitState(context.Background())) + defer func() { + if c.connect == nil { + cancel() + } + }() + // nolint:staticcheck ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName) if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil { @@ -173,7 +181,9 @@ func (c *Client) Start(startCtx context.Context) error { } recorder := peer.NewRecorder(c.config.ManagementURL.String()) + c.recorder = recorder client := internal.NewConnectClient(ctx, c.config, recorder, false) + client.SetSyncResponsePersistence(true) // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available @@ -197,6 +207,7 @@ func (c *Client) Start(startCtx context.Context) error { } c.connect = client + c.cancel = cancel return nil } @@ -211,17 +222,23 @@ func (c *Client) Stop(ctx context.Context) error { return ErrClientNotStarted } + if c.cancel != nil { + c.cancel() + c.cancel = nil + } + done := make(chan error, 1) + connect := c.connect go func() { - done <- c.connect.Stop() + done <- connect.Stop() }() select { case <-ctx.Done(): - c.cancel = nil + c.connect = nil return ctx.Err() case err := <-done: - c.cancel = nil + c.connect = nil if err != nil { return fmt.Errorf("stop: %w", err) } @@ -315,6 +332,62 @@ func (c *Client) NewHTTPClient() *http.Client { } } +// Status returns the current status of the client. +func (c *Client) Status() (peer.FullStatus, error) { + c.mu.Lock() + recorder := c.recorder + connect := c.connect + c.mu.Unlock() + + if recorder == nil { + return peer.FullStatus{}, errors.New("client not started") + } + + if connect != nil { + engine := connect.Engine() + if engine != nil { + _ = engine.RunHealthProbes(false) + } + } + + return recorder.GetFullStatus(), nil +} + +// GetLatestSyncResponse returns the latest sync response from the management server. +func (c *Client) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) { + engine, err := c.getEngine() + if err != nil { + return nil, err + } + + syncResp, err := engine.GetLatestSyncResponse() + if err != nil { + return nil, fmt.Errorf("get sync response: %w", err) + } + + return syncResp, nil +} + +// SetLogLevel sets the logging level for the client and its components. +func (c *Client) SetLogLevel(levelStr string) error { + level, err := logrus.ParseLevel(levelStr) + if err != nil { + return fmt.Errorf("parse log level: %w", err) + } + + logrus.SetLevel(level) + + c.mu.Lock() + connect := c.connect + c.mu.Unlock() + + if connect != nil { + connect.SetLogLevel(level) + } + + return nil +} + // VerifySSHHostKey verifies an SSH host key against stored peer keys. // Returns nil if the key matches, ErrPeerNotFound if peer is not in network, // ErrNoStoredKey if peer has no stored key, or an error for verification failures. diff --git a/client/internal/connect.go b/client/internal/connect.go index 017c8bf10..65637c073 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -420,6 +420,19 @@ func (c *ConnectClient) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) return syncResponse, nil } +// SetLogLevel sets the log level for the firewall manager if the engine is running. +func (c *ConnectClient) SetLogLevel(level log.Level) { + engine := c.Engine() + if engine == nil { + return + } + + fwManager := engine.GetFirewallManager() + if fwManager != nil { + fwManager.SetLogLevel(level) + } +} + // Status returns the current client status func (c *ConnectClient) Status() StatusType { if c == nil { diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 29bb7f3dc..1ce7bf1c6 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -631,9 +631,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { handler, err := newUpstreamResolver( s.ctx, - s.wgInterface.Name(), - s.wgInterface.Address().IP, - s.wgInterface.Address().Network, + s.wgInterface, s.statusRecorder, s.hostsDNSHolder, nbdns.RootZone, @@ -743,9 +741,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai log.Debugf("creating handler for domain=%s with priority=%d", domainGroup.domain, priority) handler, err := newUpstreamResolver( s.ctx, - s.wgInterface.Name(), - s.wgInterface.Address().IP, - s.wgInterface.Address().Network, + s.wgInterface, s.statusRecorder, s.hostsDNSHolder, domainGroup.domain, @@ -926,9 +922,7 @@ func (s *DefaultServer) addHostRootZone() { handler, err := newUpstreamResolver( s.ctx, - s.wgInterface.Name(), - s.wgInterface.Address().IP, - s.wgInterface.Address().Network, + s.wgInterface, s.statusRecorder, s.hostsDNSHolder, nbdns.RootZone, diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 200a5f496..31e58b9f5 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -15,6 +15,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" @@ -81,6 +82,10 @@ func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) { return configurer.WGStats{}, nil } +func (w *mocWGIface) GetNet() *netstack.Net { + return nil +} + var zoneRecords = []nbdns.SimpleRecord{ { Name: "peera.netbird.cloud", diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index c997acc75..654d280ef 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/dns/resutil" @@ -418,6 +419,56 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u return rm, t, nil } +// ExchangeWithNetstack performs a DNS exchange using netstack for dialing. +// This is needed when netstack is enabled to reach peer IPs through the tunnel. +func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) { + reply, err := netstackExchange(ctx, nsNet, r, upstream, "udp") + if err != nil { + return nil, err + } + + // If response is truncated, retry with TCP + if reply != nil && reply.MsgHdr.Truncated { + log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP", + r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + return netstackExchange(ctx, nsNet, r, upstream, "tcp") + } + + return reply, nil +} + +func netstackExchange(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream, network string) (*dns.Msg, error) { + conn, err := nsNet.DialContext(ctx, network, upstream) + if err != nil { + return nil, fmt.Errorf("with %s: %w", network, err) + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("failed to close DNS connection: %v", err) + } + }() + + if deadline, ok := ctx.Deadline(); ok { + if err := conn.SetDeadline(deadline); err != nil { + return nil, fmt.Errorf("set deadline: %w", err) + } + } + + dnsConn := &dns.Conn{Conn: conn} + + if err := dnsConn.WriteMsg(r); err != nil { + return nil, fmt.Errorf("write %s message: %w", network, err) + } + + reply, err := dnsConn.ReadMsg() + if err != nil { + return nil, fmt.Errorf("read %s message: %w", network, err) + } + + return reply, nil +} + + // FormatPeerStatus formats peer connection status information for debugging DNS timeouts func FormatPeerStatus(peerState *peer.State) string { isConnected := peerState.ConnStatus == peer.StatusConnected diff --git a/client/internal/dns/upstream_android.go b/client/internal/dns/upstream_android.go index def281f28..d7cff377b 100644 --- a/client/internal/dns/upstream_android.go +++ b/client/internal/dns/upstream_android.go @@ -23,9 +23,7 @@ type upstreamResolver struct { // first time, and we need to wait for a while to start to use again the proper DNS resolver. func newUpstreamResolver( ctx context.Context, - _ string, - _ netip.Addr, - _ netip.Prefix, + _ WGIface, statusRecorder *peer.Status, hostsDNSHolder *hostsDNSHolder, domain string, diff --git a/client/internal/dns/upstream_general.go b/client/internal/dns/upstream_general.go index 434e5880b..1143b6c51 100644 --- a/client/internal/dns/upstream_general.go +++ b/client/internal/dns/upstream_general.go @@ -5,22 +5,23 @@ package dns import ( "context" "net/netip" + "runtime" "time" "github.com/miekg/dns" + "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/internal/peer" ) type upstreamResolver struct { *upstreamResolverBase + nsNet *netstack.Net } func newUpstreamResolver( ctx context.Context, - _ string, - _ netip.Addr, - _ netip.Prefix, + wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, domain string, @@ -28,12 +29,23 @@ func newUpstreamResolver( upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain) nonIOS := &upstreamResolver{ upstreamResolverBase: upstreamResolverBase, + nsNet: wgIface.GetNet(), } upstreamResolverBase.upstreamClient = nonIOS return nonIOS, nil } func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { + // TODO: Check if upstream DNS server is routed through a peer before using netstack. + // Similar to iOS logic, we should determine if the DNS server is reachable directly + // or needs to go through the tunnel, and only use netstack when necessary. + // For now, only use netstack on JS platform where direct access is not possible. + if u.nsNet != nil && runtime.GOOS == "js" { + start := time.Now() + reply, err := ExchangeWithNetstack(ctx, u.nsNet, r, upstream) + return reply, time.Since(start), err + } + client := &dns.Client{ Timeout: ClientTimeout, } diff --git a/client/internal/dns/upstream_ios.go b/client/internal/dns/upstream_ios.go index eadcdd117..4d053a5a1 100644 --- a/client/internal/dns/upstream_ios.go +++ b/client/internal/dns/upstream_ios.go @@ -26,9 +26,7 @@ type upstreamResolverIOS struct { func newUpstreamResolver( ctx context.Context, - interfaceName string, - ip netip.Addr, - net netip.Prefix, + wgIface WGIface, statusRecorder *peer.Status, _ *hostsDNSHolder, domain string, @@ -37,9 +35,9 @@ func newUpstreamResolver( ios := &upstreamResolverIOS{ upstreamResolverBase: upstreamResolverBase, - lIP: ip, - lNet: net, - interfaceName: interfaceName, + lIP: wgIface.Address().IP, + lNet: wgIface.Address().Network, + interfaceName: wgIface.Name(), } ios.upstreamClient = ios diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index e1573e75e..2852f4775 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -2,13 +2,17 @@ package dns import ( "context" + "net" "net/netip" "strings" "testing" "time" "github.com/miekg/dns" + "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/dns/test" ) @@ -58,7 +62,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.TODO()) - resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".") + resolver, _ := newUpstreamResolver(ctx, &mockNetstackProvider{}, nil, nil, ".") // Convert test servers to netip.AddrPort var servers []netip.AddrPort for _, server := range testCase.InputServers { @@ -112,6 +116,19 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) { } } +type mockNetstackProvider struct{} + +func (m *mockNetstackProvider) Name() string { return "mock" } +func (m *mockNetstackProvider) Address() wgaddr.Address { return wgaddr.Address{} } +func (m *mockNetstackProvider) ToInterface() *net.Interface { return nil } +func (m *mockNetstackProvider) IsUserspaceBind() bool { return false } +func (m *mockNetstackProvider) GetFilter() device.PacketFilter { return nil } +func (m *mockNetstackProvider) GetDevice() *device.FilteredDevice { return nil } +func (m *mockNetstackProvider) GetNet() *netstack.Net { return nil } +func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) { + return "", nil +} + type mockUpstreamResolver struct { r *dns.Msg rtt time.Duration diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go index 28e9cebf1..717e16325 100644 --- a/client/internal/dns/wgiface.go +++ b/client/internal/dns/wgiface.go @@ -5,6 +5,8 @@ package dns import ( "net" + "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -17,4 +19,5 @@ type WGIface interface { IsUserspaceBind() bool GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetNet() *netstack.Net } diff --git a/client/internal/dns/wgiface_windows.go b/client/internal/dns/wgiface_windows.go index d1374fd54..347e0233a 100644 --- a/client/internal/dns/wgiface_windows.go +++ b/client/internal/dns/wgiface_windows.go @@ -1,6 +1,8 @@ package dns import ( + "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -12,5 +14,6 @@ type WGIface interface { IsUserspaceBind() bool GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetNet() *netstack.Net GetInterfaceGUIDString() (string, error) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 2acd86a16..0182b2530 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1748,22 +1748,26 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool { } e.syncMsgMux.Unlock() - var results []relay.ProbeResult - if waitForResult { - results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) - } else { - results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) - } - e.statusRecorder.UpdateRelayStates(results) + // Skip STUN/TURN probing for JS/WASM as it's not available relayHealthy := true - for _, res := range results { - if res.Err != nil { - relayHealthy = false - break + if runtime.GOOS != "js" { + var results []relay.ProbeResult + if waitForResult { + results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) + } else { + results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) } + e.statusRecorder.UpdateRelayStates(results) + + for _, res := range results { + if res.Err != nil { + relayHealthy = false + break + } + } + log.Debugf("relay health check: healthy=%t", relayHealthy) } - log.Debugf("relay health check: healthy=%t", relayHealthy) allHealthy := signalHealthy && managementHealthy && relayHealthy log.Debugf("all health checks completed: healthy=%t", allHealthy) diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 76f4f523c..697bda2ff 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -14,6 +14,7 @@ import ( "golang.org/x/exp/maps" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -158,6 +159,7 @@ type FullStatus struct { NSGroupStates []NSGroupState NumOfForwardingRules int LazyConnectionEnabled bool + Events []*proto.SystemEvent } type StatusChangeSubscription struct { @@ -981,6 +983,7 @@ func (d *Status) GetFullStatus() FullStatus { } fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...) + fullStatus.Events = d.GetEventHistory() return fullStatus } @@ -1181,3 +1184,97 @@ type EventSubscription struct { func (s *EventSubscription) Events() <-chan *proto.SystemEvent { return s.events } + +// ToProto converts FullStatus to proto.FullStatus. +func (fs FullStatus) ToProto() *proto.FullStatus { + pbFullStatus := proto.FullStatus{ + ManagementState: &proto.ManagementState{}, + SignalState: &proto.SignalState{}, + LocalPeerState: &proto.LocalPeerState{}, + Peers: []*proto.PeerState{}, + } + + pbFullStatus.ManagementState.URL = fs.ManagementState.URL + pbFullStatus.ManagementState.Connected = fs.ManagementState.Connected + if err := fs.ManagementState.Error; err != nil { + pbFullStatus.ManagementState.Error = err.Error() + } + + pbFullStatus.SignalState.URL = fs.SignalState.URL + pbFullStatus.SignalState.Connected = fs.SignalState.Connected + if err := fs.SignalState.Error; err != nil { + pbFullStatus.SignalState.Error = err.Error() + } + + pbFullStatus.LocalPeerState.IP = fs.LocalPeerState.IP + pbFullStatus.LocalPeerState.PubKey = fs.LocalPeerState.PubKey + pbFullStatus.LocalPeerState.KernelInterface = fs.LocalPeerState.KernelInterface + pbFullStatus.LocalPeerState.Fqdn = fs.LocalPeerState.FQDN + pbFullStatus.LocalPeerState.RosenpassPermissive = fs.RosenpassState.Permissive + pbFullStatus.LocalPeerState.RosenpassEnabled = fs.RosenpassState.Enabled + pbFullStatus.NumberOfForwardingRules = int32(fs.NumOfForwardingRules) + pbFullStatus.LazyConnectionEnabled = fs.LazyConnectionEnabled + + pbFullStatus.LocalPeerState.Networks = maps.Keys(fs.LocalPeerState.Routes) + + for _, peerState := range fs.Peers { + networks := maps.Keys(peerState.GetRoutes()) + + pbPeerState := &proto.PeerState{ + IP: peerState.IP, + PubKey: peerState.PubKey, + ConnStatus: peerState.ConnStatus.String(), + ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), + Relayed: peerState.Relayed, + LocalIceCandidateType: peerState.LocalIceCandidateType, + RemoteIceCandidateType: peerState.RemoteIceCandidateType, + LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint, + RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint, + RelayAddress: peerState.RelayServerAddress, + Fqdn: peerState.FQDN, + LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake), + BytesRx: peerState.BytesRx, + BytesTx: peerState.BytesTx, + RosenpassEnabled: peerState.RosenpassEnabled, + Networks: networks, + Latency: durationpb.New(peerState.Latency), + SshHostKey: peerState.SSHHostKey, + } + pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) + } + + for _, relayState := range fs.Relays { + pbRelayState := &proto.RelayState{ + URI: relayState.URI, + Available: relayState.Err == nil, + } + if err := relayState.Err; err != nil { + pbRelayState.Error = err.Error() + } + pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState) + } + + for _, dnsState := range fs.NSGroupStates { + var err string + if dnsState.Error != nil { + err = dnsState.Error.Error() + } + + var servers []string + for _, server := range dnsState.Servers { + servers = append(servers, server.String()) + } + + pbDnsState := &proto.NSGroupState{ + Servers: servers, + Domains: dnsState.Domains, + Enabled: dnsState.Enabled, + Error: err, + } + pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState) + } + + pbFullStatus.Events = fs.Events + + return &pbFullStatus +} diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index c7ec47da4..12c9ff4af 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -17,13 +17,13 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/common" "github.com/netbirdio/netbird/client/internal/routemanager/fakeip" + iface "github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" @@ -38,11 +38,6 @@ type internalDNATer interface { AddInternalDNATMapping(netip.Addr, netip.Addr) error } -type wgInterface interface { - Name() string - Address() wgaddr.Address -} - type DnsInterceptor struct { mu sync.RWMutex route *route.Route @@ -52,7 +47,7 @@ type DnsInterceptor struct { dnsServer nbdns.Server currentPeerKey string interceptedDomains domainMap - wgInterface wgInterface + wgInterface iface.WGIface peerStore *peerstore.Store firewall firewall.Manager fakeIPManager *fakeip.Manager @@ -250,12 +245,6 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) - if err != nil { - d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err)) - return - } - if r.Extra == nil { r.MsgHdr.AuthenticatedData = true } @@ -264,20 +253,8 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() - startTime := time.Now() - reply, _, err := nbdns.ExchangeWithFallback(ctx, client, r, upstream) - if err != nil { - if errors.Is(err, context.DeadlineExceeded) { - elapsed := time.Since(startTime) - peerInfo := d.debugPeerTimeout(upstreamIP, peerKey) - logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v", - elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err) - } else { - logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) - } - if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { - logger.Errorf("failed writing DNS response: %v", err) - } + reply := d.queryUpstreamDNS(ctx, w, r, upstream, upstreamIP, peerKey, logger) + if reply == nil { return } @@ -586,6 +563,44 @@ func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toR return } +// queryUpstreamDNS queries the upstream DNS server using netstack if available, otherwise uses regular client. +// Returns the DNS reply on success, or nil on error (error responses are written internally). +func (d *DnsInterceptor) queryUpstreamDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream string, upstreamIP netip.Addr, peerKey string, logger *log.Entry) *dns.Msg { + startTime := time.Now() + + nsNet := d.wgInterface.GetNet() + var reply *dns.Msg + var err error + + if nsNet != nil { + reply, err = nbdns.ExchangeWithNetstack(ctx, nsNet, r, upstream) + } else { + client, clientErr := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), dnsTimeout) + if clientErr != nil { + d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", clientErr)) + return nil + } + reply, _, err = nbdns.ExchangeWithFallback(ctx, client, r, upstream) + } + + if err == nil { + return reply + } + + if errors.Is(err, context.DeadlineExceeded) { + elapsed := time.Since(startTime) + peerInfo := d.debugPeerTimeout(upstreamIP, peerKey) + logger.Errorf("peer DNS timeout after %v (timeout=%v) for domain=%s to peer %s (%s)%s - error: %v", + elapsed.Truncate(time.Millisecond), dnsTimeout, r.Question[0].Name, upstreamIP.String(), peerKey, peerInfo, err) + } else { + logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) + } + if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { + logger.Errorf("failed writing DNS response: %v", err) + } + return nil +} + func (d *DnsInterceptor) debugPeerTimeout(peerIP netip.Addr, peerKey string) string { if d.statusRecorder == nil { return "" diff --git a/client/internal/routemanager/iface/iface_common.go b/client/internal/routemanager/iface/iface_common.go index f844f4bed..9b7bce751 100644 --- a/client/internal/routemanager/iface/iface_common.go +++ b/client/internal/routemanager/iface/iface_common.go @@ -4,6 +4,8 @@ import ( "net" "net/netip" + "golang.zx2c4.com/wireguard/tun/netstack" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -18,4 +20,5 @@ type wgIfaceBase interface { IsUserspaceBind() bool GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetNet() *netstack.Net } diff --git a/client/server/debug.go b/client/server/debug.go index 056d9df21..dfad41604 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -173,20 +173,9 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) ( log.SetLevel(level) - if s.connectClient == nil { - return nil, fmt.Errorf("connect client not initialized") + if s.connectClient != nil { + s.connectClient.SetLogLevel(level) } - engine := s.connectClient.Engine() - if engine == nil { - return nil, fmt.Errorf("engine not initialized") - } - - fwManager := engine.GetFirewallManager() - if fwManager == nil { - return nil, fmt.Errorf("firewall manager not initialized") - } - - fwManager.SetLogLevel(level) log.Infof("Log level set to %s", level.String()) diff --git a/client/server/event.go b/client/server/event.go index 9a4e0fbf5..b5c12a3a6 100644 --- a/client/server/event.go +++ b/client/server/event.go @@ -1,8 +1,6 @@ package server import ( - "context" - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/proto" @@ -29,8 +27,3 @@ func (s *Server) SubscribeEvents(req *proto.SubscribeRequest, stream proto.Daemo } } } - -func (s *Server) GetEvents(context.Context, *proto.GetEventsRequest) (*proto.GetEventsResponse, error) { - events := s.statusRecorder.GetEventHistory() - return &proto.GetEventsResponse{Events: events}, nil -} diff --git a/client/server/server.go b/client/server/server.go index 7b6c4e98c..d593b3f34 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -13,15 +13,12 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/protobuf/types/known/durationpb" log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" gstatus "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" @@ -1067,11 +1064,9 @@ func (s *Server) Status( if msg.GetFullPeerStatus { s.runProbes(msg.ShouldRunProbes) fullStatus := s.statusRecorder.GetFullStatus() - pbFullStatus := toProtoFullStatus(fullStatus) + pbFullStatus := fullStatus.ToProto() pbFullStatus.Events = s.statusRecorder.GetEventHistory() - pbFullStatus.SshServerState = s.getSSHServerState() - statusResponse.FullStatus = pbFullStatus } @@ -1600,94 +1595,6 @@ func parseEnvDuration(envVar string, defaultDuration time.Duration) time.Duratio return defaultDuration } -func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { - pbFullStatus := proto.FullStatus{ - ManagementState: &proto.ManagementState{}, - SignalState: &proto.SignalState{}, - LocalPeerState: &proto.LocalPeerState{}, - Peers: []*proto.PeerState{}, - } - - pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL - pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected - if err := fullStatus.ManagementState.Error; err != nil { - pbFullStatus.ManagementState.Error = err.Error() - } - - pbFullStatus.SignalState.URL = fullStatus.SignalState.URL - pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected - if err := fullStatus.SignalState.Error; err != nil { - pbFullStatus.SignalState.Error = err.Error() - } - - pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP - pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey - pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface - pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN - pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive - pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled - pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) - pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules) - pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled - - for _, peerState := range fullStatus.Peers { - pbPeerState := &proto.PeerState{ - IP: peerState.IP, - PubKey: peerState.PubKey, - ConnStatus: peerState.ConnStatus.String(), - ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), - Relayed: peerState.Relayed, - LocalIceCandidateType: peerState.LocalIceCandidateType, - RemoteIceCandidateType: peerState.RemoteIceCandidateType, - LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint, - RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint, - RelayAddress: peerState.RelayServerAddress, - Fqdn: peerState.FQDN, - LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake), - BytesRx: peerState.BytesRx, - BytesTx: peerState.BytesTx, - RosenpassEnabled: peerState.RosenpassEnabled, - Networks: maps.Keys(peerState.GetRoutes()), - Latency: durationpb.New(peerState.Latency), - SshHostKey: peerState.SSHHostKey, - } - pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) - } - - for _, relayState := range fullStatus.Relays { - pbRelayState := &proto.RelayState{ - URI: relayState.URI, - Available: relayState.Err == nil, - } - if err := relayState.Err; err != nil { - pbRelayState.Error = err.Error() - } - pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState) - } - - for _, dnsState := range fullStatus.NSGroupStates { - var err string - if dnsState.Error != nil { - err = dnsState.Error.Error() - } - - var servers []string - for _, server := range dnsState.Servers { - servers = append(servers, server.String()) - } - - pbDnsState := &proto.NSGroupState{ - Servers: servers, - Domains: dnsState.Domains, - Enabled: dnsState.Enabled, - Error: err, - } - pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState) - } - - return &pbFullStatus -} - // sendTerminalNotification sends a terminal notification message // to inform the user that the NetBird connection session has expired. func sendTerminalNotification() error { diff --git a/client/status/status.go b/client/status/status.go index 4f31f3637..305797eee 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -325,61 +325,64 @@ func sortPeersByIP(peersStateDetail []PeerStateDetailOutput) { } } -func ParseToJSON(overview OutputOverview) (string, error) { - jsonBytes, err := json.Marshal(overview) +// JSON returns the status overview as a JSON string. +func (o *OutputOverview) JSON() (string, error) { + jsonBytes, err := json.Marshal(o) if err != nil { return "", fmt.Errorf("json marshal failed") } return string(jsonBytes), err } -func ParseToYAML(overview OutputOverview) (string, error) { - yamlBytes, err := yaml.Marshal(overview) +// YAML returns the status overview as a YAML string. +func (o *OutputOverview) YAML() (string, error) { + yamlBytes, err := yaml.Marshal(o) if err != nil { return "", fmt.Errorf("yaml marshal failed") } return string(yamlBytes), nil } -func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string { +// GeneralSummary returns a general summary of the status overview. +func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameServers bool, showSSHSessions bool) string { var managementConnString string - if overview.ManagementState.Connected { + if o.ManagementState.Connected { managementConnString = "Connected" if showURL { - managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL) + managementConnString = fmt.Sprintf("%s to %s", managementConnString, o.ManagementState.URL) } } else { managementConnString = "Disconnected" - if overview.ManagementState.Error != "" { - managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error) + if o.ManagementState.Error != "" { + managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, o.ManagementState.Error) } } var signalConnString string - if overview.SignalState.Connected { + if o.SignalState.Connected { signalConnString = "Connected" if showURL { - signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL) + signalConnString = fmt.Sprintf("%s to %s", signalConnString, o.SignalState.URL) } } else { signalConnString = "Disconnected" - if overview.SignalState.Error != "" { - signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error) + if o.SignalState.Error != "" { + signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, o.SignalState.Error) } } interfaceTypeString := "Userspace" - interfaceIP := overview.IP - if overview.KernelInterface { + interfaceIP := o.IP + if o.KernelInterface { interfaceTypeString = "Kernel" - } else if overview.IP == "" { + } else if o.IP == "" { interfaceTypeString = "N/A" interfaceIP = "N/A" } var relaysString string if showRelays { - for _, relay := range overview.Relays.Details { + for _, relay := range o.Relays.Details { available := "Available" reason := "" @@ -395,18 +398,18 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { - relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) + relaysString = fmt.Sprintf("%d/%d Available", o.Relays.Available, o.Relays.Total) } networks := "-" - if len(overview.Networks) > 0 { - sort.Strings(overview.Networks) - networks = strings.Join(overview.Networks, ", ") + if len(o.Networks) > 0 { + sort.Strings(o.Networks) + networks = strings.Join(o.Networks, ", ") } var dnsServersString string if showNameServers { - for _, nsServerGroup := range overview.NSServerGroups { + for _, nsServerGroup := range o.NSServerGroups { enabled := "Available" if !nsServerGroup.Enabled { enabled = "Unavailable" @@ -430,25 +433,25 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, ) } } else { - dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups)) + dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(o.NSServerGroups), len(o.NSServerGroups)) } rosenpassEnabledStatus := "false" - if overview.RosenpassEnabled { + if o.RosenpassEnabled { rosenpassEnabledStatus = "true" - if overview.RosenpassPermissive { + if o.RosenpassPermissive { rosenpassEnabledStatus = "true (permissive)" //nolint:gosec } } lazyConnectionEnabledStatus := "false" - if overview.LazyConnectionEnabled { + if o.LazyConnectionEnabled { lazyConnectionEnabledStatus = "true" } sshServerStatus := "Disabled" - if overview.SSHServerState.Enabled { - sessionCount := len(overview.SSHServerState.Sessions) + if o.SSHServerState.Enabled { + sessionCount := len(o.SSHServerState.Sessions) if sessionCount > 0 { sessionWord := "session" if sessionCount > 1 { @@ -460,7 +463,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, } if showSSHSessions && sessionCount > 0 { - for _, session := range overview.SSHServerState.Sessions { + for _, session := range o.SSHServerState.Sessions { var sessionDisplay string if session.JWTUsername != "" { sessionDisplay = fmt.Sprintf("[%s@%s -> %s] %s", @@ -484,7 +487,7 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, } } - peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) + peersCountString := fmt.Sprintf("%d/%d Connected", o.Peers.Connected, o.Peers.Total) goos := runtime.GOOS goarch := runtime.GOARCH @@ -512,30 +515,31 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, "Forwarding rules: %d\n"+ "Peers count: %s\n", fmt.Sprintf("%s/%s%s", goos, goarch, goarm), - overview.DaemonVersion, + o.DaemonVersion, version.NetbirdVersion(), - overview.ProfileName, + o.ProfileName, managementConnString, signalConnString, relaysString, dnsServersString, - domain.Domain(overview.FQDN).SafeString(), + domain.Domain(o.FQDN).SafeString(), interfaceIP, interfaceTypeString, rosenpassEnabledStatus, lazyConnectionEnabledStatus, sshServerStatus, networks, - overview.NumberOfForwardingRules, + o.NumberOfForwardingRules, peersCountString, ) return summary } -func ParseToFullDetailSummary(overview OutputOverview) string { - parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive) - parsedEventsString := parseEvents(overview.Events) - summary := ParseGeneralSummary(overview, true, true, true, true) +// FullDetailSummary returns a full detailed summary with peer details and events. +func (o *OutputOverview) FullDetailSummary() string { + parsedPeersString := parsePeers(o.Peers, o.RosenpassEnabled, o.RosenpassPermissive) + parsedEventsString := parseEvents(o.Events) + summary := o.GeneralSummary(true, true, true, true) return fmt.Sprintf( "Peers detail:"+ diff --git a/client/status/status_test.go b/client/status/status_test.go index 1dca1e5b1..f4585827b 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -268,7 +268,7 @@ func TestSortingOfPeers(t *testing.T) { } func TestParsingToJSON(t *testing.T) { - jsonString, _ := ParseToJSON(overview) + jsonString, _ := overview.JSON() //@formatter:off expectedJSONString := ` @@ -404,7 +404,7 @@ func TestParsingToJSON(t *testing.T) { } func TestParsingToYAML(t *testing.T) { - yaml, _ := ParseToYAML(overview) + yaml, _ := overview.YAML() expectedYAML := `peers: @@ -511,7 +511,7 @@ func TestParsingToDetail(t *testing.T) { lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate) lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake) - detail := ParseToFullDetailSummary(overview) + detail := overview.FullDetailSummary() expectedDetail := fmt.Sprintf( `Peers detail: @@ -575,7 +575,7 @@ Peers count: 2/2 Connected } func TestParsingToShortVersion(t *testing.T) { - shortVersion := ParseGeneralSummary(overview, false, false, false, false) + shortVersion := overview.GeneralSummary(false, false, false, false) expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` Daemon version: 0.14.1 diff --git a/client/ui/debug.go b/client/ui/debug.go index 51fa28575..a057b2a85 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -441,7 +441,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) - postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) + postUpStatusOutput = overview.FullDetailSummary() } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, postUpStatusOutput) @@ -458,7 +458,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) - preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) + preDownStatusOutput = overview.FullDetailSummary() } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), params.duration) @@ -595,7 +595,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) - statusOutput = nbstatus.ParseToFullDetailSummary(overview) + statusOutput = overview.FullDetailSummary() } request := &proto.DebugBundleRequest{ diff --git a/client/wasm/cmd/main.go b/client/wasm/cmd/main.go index 238e272fa..2647c2f0d 100644 --- a/client/wasm/cmd/main.go +++ b/client/wasm/cmd/main.go @@ -9,20 +9,29 @@ import ( "time" log "github.com/sirupsen/logrus" + "google.golang.org/protobuf/encoding/protojson" netbird "github.com/netbirdio/netbird/client/embed" + "github.com/netbirdio/netbird/client/proto" sshdetection "github.com/netbirdio/netbird/client/ssh/detection" + nbstatus "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/ssh" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/version" ) const ( clientStartTimeout = 30 * time.Second clientStopTimeout = 10 * time.Second + pingTimeout = 10 * time.Second defaultLogLevel = "warn" defaultSSHDetectionTimeout = 20 * time.Second + + icmpEchoRequest = 8 + icmpCodeEcho = 0 + pingBufferSize = 1500 ) func main() { @@ -113,18 +122,45 @@ func createStopMethod(client *netbird.Client) js.Func { }) } +// validateSSHArgs validates SSH connection arguments +func validateSSHArgs(args []js.Value) (host string, port int, username string, err js.Value) { + if len(args) < 2 { + return "", 0, "", js.ValueOf("error: requires host and port") + } + + if args[0].Type() != js.TypeString { + return "", 0, "", js.ValueOf("host parameter must be a string") + } + if args[1].Type() != js.TypeNumber { + return "", 0, "", js.ValueOf("port parameter must be a number") + } + + host = args[0].String() + port = args[1].Int() + username = "root" + + if len(args) > 2 { + if args[2].Type() == js.TypeString && args[2].String() != "" { + username = args[2].String() + } else if args[2].Type() != js.TypeString { + return "", 0, "", js.ValueOf("username parameter must be a string") + } + } + + return host, port, username, js.Undefined() +} + // createSSHMethod creates the SSH connection method func createSSHMethod(client *netbird.Client) js.Func { return js.FuncOf(func(this js.Value, args []js.Value) any { - if len(args) < 2 { - return js.ValueOf("error: requires host and port") - } - - host := args[0].String() - port := args[1].Int() - username := "root" - if len(args) > 2 && args[2].String() != "" { - username = args[2].String() + host, port, username, validationErr := validateSSHArgs(args) + if !validationErr.IsUndefined() { + if validationErr.Type() == js.TypeString && validationErr.String() == "error: requires host and port" { + return validationErr + } + return createPromise(func(resolve, reject js.Value) { + reject.Invoke(validationErr) + }) } var jwtToken string @@ -154,6 +190,110 @@ func createSSHMethod(client *netbird.Client) js.Func { }) } +func performPing(client *netbird.Client, hostname string) { + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + + start := time.Now() + conn, err := client.Dial(ctx, "ping", hostname) + if err != nil { + js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s failed: %v", hostname, err)) + return + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("failed to close ping connection: %v", err) + } + }() + + icmpData := make([]byte, 8) + icmpData[0] = icmpEchoRequest + icmpData[1] = icmpCodeEcho + + if _, err := conn.Write(icmpData); err != nil { + js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s write failed: %v", hostname, err)) + return + } + + buf := make([]byte, pingBufferSize) + if _, err := conn.Read(buf); err != nil { + js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s read failed: %v", hostname, err)) + return + } + + latency := time.Since(start) + js.Global().Get("console").Call("log", fmt.Sprintf("Ping to %s: %dms", hostname, latency.Milliseconds())) +} + +func performPingTCP(client *netbird.Client, hostname string, port int) { + ctx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + + address := fmt.Sprintf("%s:%d", hostname, port) + start := time.Now() + conn, err := client.Dial(ctx, "tcp", address) + if err != nil { + js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s failed: %v", address, err)) + return + } + latency := time.Since(start) + + if err := conn.Close(); err != nil { + log.Debugf("failed to close TCP connection: %v", err) + } + + js.Global().Get("console").Call("log", fmt.Sprintf("TCP ping to %s succeeded: %dms", address, latency.Milliseconds())) +} + +// createPingMethod creates the ping method +func createPingMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: hostname required") + } + + if args[0].Type() != js.TypeString { + return createPromise(func(resolve, reject js.Value) { + reject.Invoke(js.ValueOf("hostname parameter must be a string")) + }) + } + + hostname := args[0].String() + return createPromise(func(resolve, reject js.Value) { + performPing(client, hostname) + resolve.Invoke(js.Undefined()) + }) + }) +} + +// createPingTCPMethod creates the pingtcp method +func createPingTCPMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 2 { + return js.ValueOf("error: hostname and port required") + } + + if args[0].Type() != js.TypeString { + return createPromise(func(resolve, reject js.Value) { + reject.Invoke(js.ValueOf("hostname parameter must be a string")) + }) + } + + if args[1].Type() != js.TypeNumber { + return createPromise(func(resolve, reject js.Value) { + reject.Invoke(js.ValueOf("port parameter must be a number")) + }) + } + + hostname := args[0].String() + port := args[1].Int() + return createPromise(func(resolve, reject js.Value) { + performPingTCP(client, hostname, port) + resolve.Invoke(js.Undefined()) + }) + }) +} + // createProxyRequestMethod creates the proxyRequest method func createProxyRequestMethod(client *netbird.Client) js.Func { return js.FuncOf(func(this js.Value, args []js.Value) any { @@ -162,6 +302,11 @@ func createProxyRequestMethod(client *netbird.Client) js.Func { } request := args[0] + if request.Type() != js.TypeObject { + return createPromise(func(resolve, reject js.Value) { + reject.Invoke(js.ValueOf("request parameter must be an object")) + }) + } return createPromise(func(resolve, reject js.Value) { response, err := http.ProxyRequest(client, request) @@ -181,11 +326,145 @@ func createRDPProxyMethod(client *netbird.Client) js.Func { return js.ValueOf("error: hostname and port required") } + if args[0].Type() != js.TypeString { + return createPromise(func(resolve, reject js.Value) { + reject.Invoke(js.ValueOf("hostname parameter must be a string")) + }) + } + if args[1].Type() != js.TypeString { + return createPromise(func(resolve, reject js.Value) { + reject.Invoke(js.ValueOf("port parameter must be a string")) + }) + } + proxy := rdp.NewRDCleanPathProxy(client) return proxy.CreateProxy(args[0].String(), args[1].String()) }) } +// getStatusOverview is a helper to get the status overview +func getStatusOverview(client *netbird.Client) (nbstatus.OutputOverview, error) { + fullStatus, err := client.Status() + if err != nil { + return nbstatus.OutputOverview{}, err + } + + pbFullStatus := fullStatus.ToProto() + statusResp := &proto.StatusResponse{ + DaemonVersion: version.NetbirdVersion(), + FullStatus: pbFullStatus, + } + + return nbstatus.ConvertToStatusOutputOverview(statusResp, false, "", nil, nil, nil, "", ""), nil +} + +// createStatusMethod creates the status method that returns JSON +func createStatusMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + overview, err := getStatusOverview(client) + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + jsonStr, err := overview.JSON() + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + jsonObj := js.Global().Get("JSON").Call("parse", jsonStr) + resolve.Invoke(jsonObj) + }) + }) +} + +// createStatusSummaryMethod creates the statusSummary method +func createStatusSummaryMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + overview, err := getStatusOverview(client) + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + summary := overview.GeneralSummary(false, false, false, false) + js.Global().Get("console").Call("log", summary) + resolve.Invoke(js.Undefined()) + }) + }) +} + +// createStatusDetailMethod creates the statusDetail method +func createStatusDetailMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + overview, err := getStatusOverview(client) + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + detail := overview.FullDetailSummary() + js.Global().Get("console").Call("log", detail) + resolve.Invoke(js.Undefined()) + }) + }) +} + +// createGetSyncResponseMethod creates the getSyncResponse method that returns the latest sync response as JSON +func createGetSyncResponseMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + return createPromise(func(resolve, reject js.Value) { + syncResp, err := client.GetLatestSyncResponse() + if err != nil { + reject.Invoke(js.ValueOf(err.Error())) + return + } + + options := protojson.MarshalOptions{ + EmitUnpopulated: true, + UseProtoNames: true, + AllowPartial: true, + } + jsonBytes, err := options.Marshal(syncResp) + if err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("marshal sync response: %v", err))) + return + } + + jsonObj := js.Global().Get("JSON").Call("parse", string(jsonBytes)) + resolve.Invoke(jsonObj) + }) + }) +} + +// createSetLogLevelMethod creates the setLogLevel method to dynamically change logging level +func createSetLogLevelMethod(client *netbird.Client) js.Func { + return js.FuncOf(func(_ js.Value, args []js.Value) any { + if len(args) < 1 { + return js.ValueOf("error: log level required") + } + + if args[0].Type() != js.TypeString { + return createPromise(func(resolve, reject js.Value) { + reject.Invoke(js.ValueOf("log level parameter must be a string")) + }) + } + + logLevel := args[0].String() + return createPromise(func(resolve, reject js.Value) { + if err := client.SetLogLevel(logLevel); err != nil { + reject.Invoke(js.ValueOf(fmt.Sprintf("set log level: %v", err))) + return + } + log.Infof("Log level set to: %s", logLevel) + resolve.Invoke(js.ValueOf(true)) + }) + }) +} + // createPromise is a helper to create JavaScript promises func createPromise(handler func(resolve, reject js.Value)) js.Value { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { @@ -237,17 +516,24 @@ func createClientObject(client *netbird.Client) js.Value { obj["start"] = createStartMethod(client) obj["stop"] = createStopMethod(client) + obj["ping"] = createPingMethod(client) + obj["pingtcp"] = createPingTCPMethod(client) obj["detectSSHServerType"] = createDetectSSHServerMethod(client) obj["createSSHConnection"] = createSSHMethod(client) obj["proxyRequest"] = createProxyRequestMethod(client) obj["createRDPProxy"] = createRDPProxyMethod(client) + obj["status"] = createStatusMethod(client) + obj["statusSummary"] = createStatusSummaryMethod(client) + obj["statusDetail"] = createStatusDetailMethod(client) + obj["getSyncResponse"] = createGetSyncResponseMethod(client) + obj["setLogLevel"] = createSetLogLevelMethod(client) return js.ValueOf(obj) } // netBirdClientConstructor acts as a JavaScript constructor function -func netBirdClientConstructor(this js.Value, args []js.Value) any { - return js.Global().Get("Promise").New(js.FuncOf(func(this js.Value, promiseArgs []js.Value) any { +func netBirdClientConstructor(_ js.Value, args []js.Value) any { + return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { resolve := promiseArgs[0] reject := promiseArgs[1]