diff --git a/client/android/client.go b/client/android/client.go index c05246569..678f5d9d5 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -112,7 +112,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, "") return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) } @@ -138,7 +138,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, "") return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener) } diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..0e5e527a0 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -308,7 +308,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string { cmd.PrintErrf("Failed to get status: %v\n", err) } else { statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp.GetFullStatus(), anon, statusResp.GetDaemonVersion(), "", nil, nil, nil, "", ""), ) } return statusOutputString diff --git a/client/cmd/status.go b/client/cmd/status.go index 723f2367c..01816e9d8 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { profName = activeProf.Name } - var outputInformationHolder = 
nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName) + var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName) var statusOutputString string switch { case detailFlag: diff --git a/client/cmd/up.go b/client/cmd/up.go index 1fa58e6ed..7cc342fe0 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -196,7 +196,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr r := peer.NewRecorder(config.ManagementURL.String()) r.GetFullStatus() - connectClient := internal.NewConnectClient(ctx, config, r) + //todo: do we need to pass logFile here ? + connectClient := internal.NewConnectClient(ctx, config, r, "") SetupDebugHandler(ctx, config, r, connectClient, "") return connectClient.Run(nil) diff --git a/client/embed/embed.go b/client/embed/embed.go index de83f9d96..79f5f0e43 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -131,7 +131,9 @@ func (c *Client) Start(startCtx context.Context) error { } recorder := peer.NewRecorder(c.config.ManagementURL.String()) - client := internal.NewConnectClient(ctx, c.config, recorder) + + //todo: do we need to pass logFile here ? 
+ client := internal.NewConnectClient(ctx, c.config, recorder, "") // either startup error (permanent backoff err) or nil err (successful engine up) // TODO: make after-startup backoff err available diff --git a/client/internal/connect.go b/client/internal/connect.go index 523dcaf1f..b62a2d951 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -45,17 +45,19 @@ type ConnectClient struct { engineMutex sync.Mutex persistSyncResponse bool + LogFile string } func NewConnectClient( ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, - + logFile string, ) *ConnectClient { return &ConnectClient{ ctx: ctx, config: config, + LogFile: logFile, statusRecorder: statusRecorder, engineMutex: sync.Mutex{}, } @@ -261,7 +263,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan peerConfig := loginResp.GetPeerConfig() - engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) + engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, c.LogFile) if err != nil { log.Error(err) return wrapErr(err) @@ -270,7 +272,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan checks := loginResp.GetChecks() c.engineMutex.Lock() - c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks) + c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, c.config) c.engine.SetSyncResponsePersistence(c.persistSyncResponse) c.engineMutex.Unlock() @@ -415,7 +417,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) { } // createEngineConfig converts configuration received from Management Service to EngineConfig -func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { +func createEngineConfig(key wgtypes.Key, config 
*profilemanager.Config, peerConfig *mgmProto.PeerConfig, logFile string) (*EngineConfig, error) { nm := false if config.NetworkMonitor != nil { nm = *config.NetworkMonitor @@ -444,6 +446,9 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf BlockInbound: config.BlockInbound, LazyConnectionEnabled: config.LazyConnectionEnabled, + LogFile: logFile, + + ProfileConfig: config, } if config.PreSharedKey != "" { diff --git a/client/internal/debug/upload.go b/client/internal/debug/upload.go new file mode 100644 index 000000000..cdf52409d --- /dev/null +++ b/client/internal/debug/upload.go @@ -0,0 +1,101 @@ +package debug + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + + "github.com/netbirdio/netbird/upload-server/types" +) + +const maxBundleUploadSize = 50 * 1024 * 1024 + +func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) { + response, err := getUploadURL(ctx, url, managementURL) + if err != nil { + return "", err + } + + err = upload(ctx, filePath, response) + if err != nil { + return "", err + } + return response.Key, nil +} + +func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error { + fileData, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("open file: %w", err) + } + + defer fileData.Close() + + stat, err := fileData.Stat() + if err != nil { + return fmt.Errorf("stat file: %w", err) + } + + if stat.Size() > maxBundleUploadSize { + return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize) + } + + req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData) + if err != nil { + return fmt.Errorf("create PUT request: %w", err) + } + + req.ContentLength = stat.Size() + req.Header.Set("Content-Type", "application/octet-stream") + + putResp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("upload failed: %v", err) + } + defer 
putResp.Body.Close() + + if putResp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(putResp.Body) + return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body)) + } + return nil +} + +func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) { + id := getURLHash(managementURL) + getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil) + if err != nil { + return nil, fmt.Errorf("create GET request: %w", err) + } + + getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue) + + resp, err := http.DefaultClient.Do(getReq) + if err != nil { + return nil, fmt.Errorf("get presigned URL: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body)) + } + + urlBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + var response types.GetURLResponse + if err := json.Unmarshal(urlBytes, &response); err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + return &response, nil +} + +func getURLHash(url string) string { + return fmt.Sprintf("%x", sha256.Sum256([]byte(url))) +} diff --git a/client/server/debug_test.go b/client/internal/debug/upload_test.go similarity index 93% rename from client/server/debug_test.go rename to client/internal/debug/upload_test.go index 53d9ac8ed..e833c196d 100644 --- a/client/server/debug_test.go +++ b/client/internal/debug/upload_test.go @@ -1,4 +1,4 @@ -package server +package debug import ( "context" @@ -38,7 +38,7 @@ func TestUpload(t *testing.T) { fileContent := []byte("test file content") err := os.WriteFile(file, fileContent, 0640) require.NoError(t, err) - key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file) + key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, 
testURL, file) require.NoError(t, err) id := getURLHash(testURL) require.Contains(t, key, id+"/") diff --git a/client/internal/engine.go b/client/internal/engine.go index 14f5e9ae8..4e847758d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -32,6 +32,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/acl" + "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/ingressgw" @@ -48,11 +49,13 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/jobexec" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" nbssh "github.com/netbirdio/netbird/client/ssh" + nbstatus "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" @@ -63,6 +66,7 @@ import ( signal "github.com/netbirdio/netbird/shared/signal/client" sProto "github.com/netbirdio/netbird/shared/signal/proto" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/version" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. @@ -125,6 +129,11 @@ type EngineConfig struct { BlockInbound bool LazyConnectionEnabled bool + + // for debug bundle generation + ProfileConfig *profilemanager.Config + + LogFile string } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. 
@@ -189,11 +198,15 @@ type Engine struct { stateManager *statemanager.Manager srWatcher *guard.SRWatcher - // Sync response persistence + // Sync response persistence (protected by syncRespMux) + syncRespMux sync.RWMutex persistSyncResponse bool latestSyncResponse *mgmProto.SyncResponse connSemaphore *semaphoregroup.SemaphoreGroup flowManager nftypes.FlowManager + + jobExecutor *jobexec.Executor + jobExecutorWG sync.WaitGroup } // Peer is an instance of the Connection Peer @@ -207,17 +220,7 @@ type localIpUpdater interface { } // NewEngine creates a new Connection Engine with probes attached -func NewEngine( - clientCtx context.Context, - clientCancel context.CancelFunc, - signalClient signal.Client, - mgmClient mgm.Client, - relayManager *relayClient.Manager, - config *EngineConfig, - mobileDep MobileDependency, - statusRecorder *peer.Status, - checks []*mgmProto.Checks, -) *Engine { +func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, c *profilemanager.Config) *Engine { engine := &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, @@ -236,6 +239,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), + jobExecutor: jobexec.NewExecutor(), } sm := profilemanager.NewServiceManager("") @@ -314,6 +318,8 @@ func (e *Engine) Stop() error { e.cancel() } + e.jobExecutorWG.Wait() // block until job goroutines finish + // very ugly but we want to remove peers from the WireGuard interface first before removing interface. 
// Removing peers happens in the conn.Close() asynchronously time.Sleep(500 * time.Millisecond) @@ -699,9 +705,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return nil } + // Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux. + // Read the storage-enabled flag under the syncRespMux too. + e.syncRespMux.RLock() + enabled := e.persistSyncResponse + e.syncRespMux.RUnlock() + // Store sync response if persistence is enabled - if e.persistSyncResponse { + if enabled { + e.syncRespMux.Lock() e.latestSyncResponse = update + e.syncRespMux.Unlock() + log.Debugf("sync response persisted with serial %d", nm.GetSerial()) } @@ -886,20 +901,27 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { return nil } - func (e *Engine) receiveJobEvents() { + e.jobExecutorWG.Add(1) go func() { + defer e.jobExecutorWG.Done() err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse { - // Simple test handler — replace with real logic - log.Infof("Received job request: %+v", msg) - // TODO: trigger local debug bundle or other job - return &mgmProto.JobResponse{ - ID: msg.ID, - WorkloadResults: &mgmProto.JobResponse_Bundle{ - Bundle: &mgmProto.BundleResult{ - UploadKey: "upload-key", - }, - }, + resp := mgmProto.JobResponse{ + ID: msg.ID, + Status: mgmProto.JobStatus_failed, + } + switch params := msg.WorkloadParameters.(type) { + case *mgmProto.JobRequest_Bundle: + bundleResult, err := e.handleBundle(params.Bundle) + if err != nil { + resp.Reason = []byte(err.Error()) + return &resp + } + resp.Status = mgmProto.JobStatus_succeeded + resp.WorkloadResults = bundleResult + return &resp + default: + return nil } }) if err != nil { @@ -914,6 +936,49 @@ func (e *Engine) receiveJobEvents() { log.Debugf("connecting to Management Service jobs stream") } +func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) { + syncResponse, err := 
e.GetLatestSyncResponse() + if err != nil { + return nil, fmt.Errorf("get latest sync response: %w", err) + } + + if syncResponse == nil { + return nil, errors.New("sync response is not available") + } + + // convert fullStatus to statusOutput + fullStatus := e.statusRecorder.GetFullStatus() + protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus) + overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, params.Anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", "") + statusOutput := nbstatus.ParseToFullDetailSummary(overview) + + bundleDeps := debug.GeneratorDependencies{ + InternalConfig: e.config.ProfileConfig, + StatusRecorder: e.statusRecorder, + SyncResponse: syncResponse, + LogFile: e.config.LogFile, + } + + bundleJobParams := debug.BundleConfig{ + Anonymize: params.Anonymize, + ClientStatus: statusOutput, + IncludeSystemInfo: true, + LogFileCount: uint32(params.LogFileCount), + } + + uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, e.config.ProfileConfig.ManagementURL.String()) + if err != nil { + return nil, err + } + + response := &mgmProto.JobResponse_Bundle{ + Bundle: &mgmProto.BundleResult{ + UploadKey: uploadKey, + }, + } + return response, nil +} + // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // E.g. when a new peer has been registered and we are allowed to connect to it. 
func (e *Engine) receiveManagementEvents() { @@ -1761,8 +1826,8 @@ func (e *Engine) stopDNSServer() { // SetSyncResponsePersistence enables or disables sync response persistence func (e *Engine) SetSyncResponsePersistence(enabled bool) { - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() + e.syncRespMux.Lock() + defer e.syncRespMux.Unlock() if enabled == e.persistSyncResponse { return @@ -1777,20 +1842,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) { // GetLatestSyncResponse returns the stored sync response if persistence is enabled func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) { - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() + e.syncRespMux.RLock() + enabled := e.persistSyncResponse + latest := e.latestSyncResponse + e.syncRespMux.RUnlock() - if !e.persistSyncResponse { + if !enabled { return nil, errors.New("sync response persistence is disabled") } - if e.latestSyncResponse == nil { + if latest == nil { //nolint:nilnil return nil, nil } - log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse)) - sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse) + log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest)) + sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse) if !ok { return nil, fmt.Errorf("failed to clone sync response") } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index c2c9ae84a..1a179c6ce 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -27,6 +27,7 @@ import ( "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/groups" @@ -219,22 +220,13 @@ func TestEngine_SSH(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) - engine := NewEngine( - ctx, 
cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, - &EngineConfig{ - WgIfaceName: "utun101", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - ServerSSHAllowed: true, - }, - MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil, - ) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + WgIfaceName: "utun101", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + ServerSSHAllowed: true, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -364,20 +356,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { defer cancel() relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) - engine := NewEngine( - ctx, cancel, - &signal.MockClient{}, - &mgmt.MockClient{}, - relayMgr, - &EngineConfig{ - WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - }, - MobileDependency{}, - peer.NewRecorder("https://mgm"), - nil) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ + WgIfaceName: "utun102", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) wgIface := &MockWGIface{ NameFunc: func() string { return "utun102" }, @@ -595,7 +579,7 @@ func TestEngine_Sync(t *testing.T) { WgAddr: "100.64.0.1/24", WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx engine.dnsServer = &dns.MockServer{ @@ -759,7 +743,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, 
nil) engine.ctx = ctx newNet, err := stdnet.NewNet() if err != nil { @@ -960,7 +944,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { WgAddr: wgAddr, WgPrivateKey: key, WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil) engine.ctx = ctx newNet, err := stdnet.NewNet() @@ -1484,7 +1468,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin } relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil + e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil e.ctx = ctx return e, err } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 2109d4b15..e50962f9e 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) // ConnectionListener export internal Listener for mobile @@ -127,7 +127,8 @@ func (c *Client) Run(fd int32, interfaceName string) error { c.onHostDnsFn = func([]string) {} cfg.WgIface = interfaceName - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) + //todo: do we need to pass logFile here ? 
+ c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, "") return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } diff --git a/client/jobexec/executor.go b/client/jobexec/executor.go new file mode 100644 index 000000000..4d64d6fbf --- /dev/null +++ b/client/jobexec/executor.go @@ -0,0 +1,35 @@ +package jobexec + +import ( + "context" + "fmt" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/debug" + "github.com/netbirdio/netbird/upload-server/types" +) + +type Executor struct { +} + +func NewExecutor() *Executor { + return &Executor{} +} + +func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, mgmURL string) (string, error) { + bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params) + + path, err := bundleGenerator.Generate() + if err != nil { + return "", fmt.Errorf("generate debug bundle: %w", err) + } + + key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path) + if err != nil { + log.Errorf("failed to upload debug bundle to %v", err) + return "", fmt.Errorf("upload debug bundle: %w", err) + } + + return key, nil +} diff --git a/client/server/debug.go b/client/server/debug.go index 056d9df21..330b476e6 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -4,24 +4,16 @@ package server import ( "context" - "crypto/sha256" - "encoding/json" "errors" "fmt" - "io" - "net/http" - "os" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto" - "github.com/netbirdio/netbird/upload-server/types" ) -const maxBundleUploadSize = 50 * 1024 * 1024 - // DebugBundle creates a debug bundle and returns the location. 
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { s.mutex.Lock() @@ -55,7 +47,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( if req.GetUploadURL() == "" { return &proto.DebugBundleResponse{Path: path}, nil } - key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path) + key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path) if err != nil { log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err) return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil @@ -66,92 +58,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil } -func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) { - response, err := getUploadURL(ctx, url, managementURL) - if err != nil { - return "", err - } - - err = upload(ctx, filePath, response) - if err != nil { - return "", err - } - return response.Key, nil -} - -func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error { - fileData, err := os.Open(filePath) - if err != nil { - return fmt.Errorf("open file: %w", err) - } - - defer fileData.Close() - - stat, err := fileData.Stat() - if err != nil { - return fmt.Errorf("stat file: %w", err) - } - - if stat.Size() > maxBundleUploadSize { - return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize) - } - - req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData) - if err != nil { - return fmt.Errorf("create PUT request: %w", err) - } - - req.ContentLength = stat.Size() - req.Header.Set("Content-Type", "application/octet-stream") - - putResp, err := http.DefaultClient.Do(req) - if err != nil { - return 
fmt.Errorf("upload failed: %v", err) - } - defer putResp.Body.Close() - - if putResp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(putResp.Body) - return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body)) - } - return nil -} - -func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) { - id := getURLHash(managementURL) - getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil) - if err != nil { - return nil, fmt.Errorf("create GET request: %w", err) - } - - getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue) - - resp, err := http.DefaultClient.Do(getReq) - if err != nil { - return nil, fmt.Errorf("get presigned URL: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body)) - } - - urlBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response body: %w", err) - } - var response types.GetURLResponse - if err := json.Unmarshal(urlBytes, &response); err != nil { - return nil, fmt.Errorf("unmarshal response: %w", err) - } - return &response, nil -} - -func getURLHash(url string) string { - return fmt.Sprintf("%x", sha256.Sum256([]byte(url))) -} - // GetLogLevel gets the current logging level for the server. 
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) { s.mutex.Lock() diff --git a/client/server/server.go b/client/server/server.go index f2e8dc12a..dd842d099 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -13,15 +13,12 @@ import ( "time" "github.com/cenkalti/backoff/v4" - "golang.org/x/exp/maps" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/protobuf/types/known/durationpb" log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" gstatus "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/timestamppb" "github.com/netbirdio/netbird/client/internal/auth" "github.com/netbirdio/netbird/client/internal/profilemanager" @@ -32,6 +29,7 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" + nbstatus "github.com/netbirdio/netbird/client/status" "github.com/netbirdio/netbird/version" ) @@ -235,7 +233,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage runOperation := func() error { log.Tracef("running client connection") - s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, s.logFile) s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse) err := s.connectClient.Run(runningChan) @@ -1026,7 +1024,7 @@ func (s *Server) Status( } fullStatus := s.statusRecorder.GetFullStatus() - pbFullStatus := toProtoFullStatus(fullStatus) + pbFullStatus := nbstatus.ToProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() statusResponse.FullStatus = pbFullStatus } @@ -1131,93 +1129,6 @@ func (s *Server) onSessionExpire() { } } -func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { - pbFullStatus := proto.FullStatus{ - ManagementState: 
&proto.ManagementState{}, - SignalState: &proto.SignalState{}, - LocalPeerState: &proto.LocalPeerState{}, - Peers: []*proto.PeerState{}, - } - - pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL - pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected - if err := fullStatus.ManagementState.Error; err != nil { - pbFullStatus.ManagementState.Error = err.Error() - } - - pbFullStatus.SignalState.URL = fullStatus.SignalState.URL - pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected - if err := fullStatus.SignalState.Error; err != nil { - pbFullStatus.SignalState.Error = err.Error() - } - - pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP - pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey - pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface - pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN - pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive - pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled - pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) - pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules) - pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled - - for _, peerState := range fullStatus.Peers { - pbPeerState := &proto.PeerState{ - IP: peerState.IP, - PubKey: peerState.PubKey, - ConnStatus: peerState.ConnStatus.String(), - ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), - Relayed: peerState.Relayed, - LocalIceCandidateType: peerState.LocalIceCandidateType, - RemoteIceCandidateType: peerState.RemoteIceCandidateType, - LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint, - RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint, - RelayAddress: peerState.RelayServerAddress, - Fqdn: peerState.FQDN, - LastWireguardHandshake: 
timestamppb.New(peerState.LastWireguardHandshake), - BytesRx: peerState.BytesRx, - BytesTx: peerState.BytesTx, - RosenpassEnabled: peerState.RosenpassEnabled, - Networks: maps.Keys(peerState.GetRoutes()), - Latency: durationpb.New(peerState.Latency), - } - pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) - } - - for _, relayState := range fullStatus.Relays { - pbRelayState := &proto.RelayState{ - URI: relayState.URI, - Available: relayState.Err == nil, - } - if err := relayState.Err; err != nil { - pbRelayState.Error = err.Error() - } - pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState) - } - - for _, dnsState := range fullStatus.NSGroupStates { - var err string - if dnsState.Error != nil { - err = dnsState.Error.Error() - } - - var servers []string - for _, server := range dnsState.Servers { - servers = append(servers, server.String()) - } - - pbDnsState := &proto.NSGroupState{ - Servers: servers, - Domains: dnsState.Domains, - Enabled: dnsState.Enabled, - Error: err, - } - pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState) - } - - return &pbFullStatus -} - // sendTerminalNotification sends a terminal notification message // to inform the user that the NetBird connection session has expired. 
func sendTerminalNotification() error { diff --git a/client/status/status.go b/client/status/status.go index db5b7dc0b..df132a42f 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" "gopkg.in/yaml.v3" "github.com/netbirdio/netbird/client/anonymize" @@ -18,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/version" + "golang.org/x/exp/maps" ) type PeerStateDetailOutput struct { @@ -101,9 +104,7 @@ type OutputOverview struct { ProfileName string `json:"profileName" yaml:"profileName"` } -func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { - pbFullStatus := resp.GetFullStatus() - +func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview { managementState := pbFullStatus.GetManagementState() managementOverview := ManagementStateOutput{ URL: managementState.GetURL(), @@ -119,12 +120,12 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status } relayOverview := mapRelays(pbFullStatus.GetRelays()) - peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) + peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter) overview := OutputOverview{ Peers: peersOverview, CliVersion: 
version.NetbirdVersion(), - DaemonVersion: resp.GetDaemonVersion(), + DaemonVersion: daemonVersion, ManagementState: managementOverview, SignalState: signalOverview, Relays: relayOverview, @@ -458,6 +459,93 @@ func ParseToFullDetailSummary(overview OutputOverview) string { ) } +func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { + pbFullStatus := proto.FullStatus{ + ManagementState: &proto.ManagementState{}, + SignalState: &proto.SignalState{}, + LocalPeerState: &proto.LocalPeerState{}, + Peers: []*proto.PeerState{}, + } + + pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL + pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected + if err := fullStatus.ManagementState.Error; err != nil { + pbFullStatus.ManagementState.Error = err.Error() + } + + pbFullStatus.SignalState.URL = fullStatus.SignalState.URL + pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected + if err := fullStatus.SignalState.Error; err != nil { + pbFullStatus.SignalState.Error = err.Error() + } + + pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP + pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey + pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface + pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN + pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive + pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled + pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) + pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules) + pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled + + for _, peerState := range fullStatus.Peers { + pbPeerState := &proto.PeerState{ + IP: peerState.IP, + PubKey: peerState.PubKey, + ConnStatus: peerState.ConnStatus.String(), + ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), + 
Relayed: peerState.Relayed, + LocalIceCandidateType: peerState.LocalIceCandidateType, + RemoteIceCandidateType: peerState.RemoteIceCandidateType, + LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint, + RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint, + RelayAddress: peerState.RelayServerAddress, + Fqdn: peerState.FQDN, + LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake), + BytesRx: peerState.BytesRx, + BytesTx: peerState.BytesTx, + RosenpassEnabled: peerState.RosenpassEnabled, + Networks: maps.Keys(peerState.GetRoutes()), + Latency: durationpb.New(peerState.Latency), + } + pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) + } + + for _, relayState := range fullStatus.Relays { + pbRelayState := &proto.RelayState{ + URI: relayState.URI, + Available: relayState.Err == nil, + } + if err := relayState.Err; err != nil { + pbRelayState.Error = err.Error() + } + pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState) + } + + for _, dnsState := range fullStatus.NSGroupStates { + var err string + if dnsState.Error != nil { + err = dnsState.Error.Error() + } + + var servers []string + for _, server := range dnsState.Servers { + servers = append(servers, server.String()) + } + + pbDnsState := &proto.NSGroupState{ + Servers: servers, + Domains: dnsState.Domains, + Enabled: dnsState.Enabled, + Error: err, + } + pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState) + } + + return &pbFullStatus +} + func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string { var ( peersString = "" @@ -737,3 +825,4 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) { } } } + diff --git a/client/status/status_test.go b/client/status/status_test.go index 660efd9ef..5c40938a0 100644 --- a/client/status/status_test.go +++ b/client/status/status_test.go @@ -234,7 +234,7 @@ var overview = OutputOverview{ } func TestConversionFromFullStatusToOutputOverview(t 
*testing.T) { - convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "", "") + convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "") assert.Equal(t, overview, convertedResult) } diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..3d12faf94 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus.GetFullStatus(), params.anonymize, postUpStatus.GetDaemonVersion(), "", nil, nil, nil, "", "") postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus.GetFullStatus(), params.anonymize, preDownStatus.GetDaemonVersion(), "", nil, nil, nil, "", "") preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp.GetFullStatus(), anonymize, statusResp.GetDaemonVersion(), "", nil, nil, nil, "", "") statusOutput = nbstatus.ParseToFullDetailSummary(overview) } diff --git 
a/management/server/grpcserver.go b/management/server/grpcserver.go index 4d57a8095..65e931f18 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -163,11 +163,11 @@ func (s *GRPCServer) Job(srv proto.ManagementService_JobServer) error { } // Start background response handler - s.startResponseReceiver(ctx, accountID, srv) + s.startResponseReceiver(ctx, srv) // Prepare per-peer state - updates := s.jobManager.CreateJobChannel(peer.ID) - log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) + updates := s.jobManager.CreateJobChannel(ctx, accountID, peer.ID) + log.WithContext(ctx).Debugf("Job: took %v", time.Since(reqStart)) // Main loop: forward jobs to client return s.sendJobsLoop(ctx, accountID, peerKey, peer, updates, srv) @@ -262,7 +262,7 @@ func (s *GRPCServer) handleHandshake(ctx context.Context, srv proto.ManagementSe return peerKey, nil } -func (s *GRPCServer) startResponseReceiver(ctx context.Context, accountID string, srv proto.ManagementService_JobServer) { +func (s *GRPCServer) startResponseReceiver(ctx context.Context, srv proto.ManagementService_JobServer) { go func() { for { msg, err := srv.Recv() @@ -280,7 +280,7 @@ func (s *GRPCServer) startResponseReceiver(ctx context.Context, accountID string continue } - if err := s.jobManager.HandleResponse(ctx, accountID, jobResp); err != nil { + if err := s.jobManager.HandleResponse(ctx, jobResp); err != nil { log.WithContext(ctx).Errorf("handle job response failed: %v", err) } diff --git a/management/server/jobChannel.go b/management/server/jobChannel.go index e9c98d2b8..3dbbe0e2c 100644 --- a/management/server/jobChannel.go +++ b/management/server/jobChannel.go @@ -8,7 +8,9 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/proto" + log "github.com/sirupsen/logrus" ) const 
jobChannelBuffer = 100 @@ -17,7 +19,6 @@ type JobEvent struct { PeerID string Request *proto.JobRequest Response *proto.JobResponse - Done chan struct{} // closed when response arrives } type JobManager struct { @@ -42,9 +43,11 @@ func NewJobManager(metrics telemetry.AppMetrics, store store.Store) *JobManager } // CreateJobChannel creates or replaces a channel for a peer -func (jm *JobManager) CreateJobChannel(peerID string) chan *JobEvent { - // TODO: all pending jobs stored in db for this peer should be failed - // jm.Store.MarkPendingJobsAsFailed(peerID) +func (jm *JobManager) CreateJobChannel(ctx context.Context, accountID, peerID string) chan *JobEvent { + // all pending jobs stored in db for this peer should be failed + if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, "Pending job cleanup: marked as failed automatically due to being stuck too long"); err != nil { + log.WithContext(ctx).Error(err.Error()) + } jm.mu.Lock() defer jm.mu.Unlock() @@ -71,7 +74,6 @@ func (jm *JobManager) SendJob(ctx context.Context, accountID, peerID string, req event := &JobEvent{ PeerID: peerID, Request: req, - Done: make(chan struct{}), } jm.mu.Lock() @@ -80,14 +82,6 @@ func (jm *JobManager) SendJob(ctx context.Context, accountID, peerID string, req select { case ch <- event: - case <-time.After(5 * time.Second): - jm.cleanup(ctx, accountID, string(req.ID), "timed out") - return fmt.Errorf("job channel full for peer %s", peerID) - } - - select { - case <-event.Done: - return nil case <-time.After(jm.responseWait): jm.cleanup(ctx, accountID, string(req.ID), "timed out") return fmt.Errorf("job %s timed out", req.ID) @@ -95,25 +89,32 @@ func (jm *JobManager) SendJob(ctx context.Context, accountID, peerID string, req jm.cleanup(ctx, accountID, string(req.ID), ctx.Err().Error()) return ctx.Err() } + return nil } // HandleResponse marks a job as finished and moves it to completed -func (jm *JobManager) HandleResponse(ctx context.Context, accountID string, resp 
*proto.JobResponse) error { +func (jm *JobManager) HandleResponse(ctx context.Context, resp *proto.JobResponse) error { jm.mu.Lock() defer jm.mu.Unlock() - event, ok := jm.pending[string(resp.ID)] + jobID := string(resp.ID) + + event, ok := jm.pending[jobID] if !ok { - return fmt.Errorf("job %s not found", resp.ID) + return fmt.Errorf("job %s not found", jobID) + } + var job types.Job + if err := job.ApplyResponse(resp); err != nil { + return fmt.Errorf("invalid job response: %v", err) + } + //update or create the store for job response + err := jm.Store.CompletePeerJob(ctx, &job) + if err == nil { + event.Response = resp } - event.Response = resp - //TODO: update the store for job response - // jm.store.CompleteJob(ctx,accountID, string(resp.GetID()), string(resp.GetResult()),string(resp.GetReason())) - close(event.Done) - delete(jm.pending, string(resp.ID)) - - return nil + delete(jm.pending, jobID) + return err } // CloseChannel closes a peer’s channel and cleans up its jobs @@ -130,7 +131,9 @@ func (jm *JobManager) CloseChannel(ctx context.Context, accountID, peerID string for jobID, ev := range jm.pending { if ev.PeerID == peerID { // if the client disconnect and there is pending job then marke it as failed - // jm.store.CompleteJob(ctx,accountID, jobID,"", "Time out ") + if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, "Time out peer disconnected"); err != nil { + log.WithContext(ctx).Errorf(err.Error()) + } delete(jm.pending, jobID) } } @@ -142,8 +145,29 @@ func (jm *JobManager) cleanup(ctx context.Context, accountID, jobID string, reas defer jm.mu.Unlock() if ev, ok := jm.pending[jobID]; ok { - close(ev.Done) - // jm.store.CompleteJob(ctx, accountID, jobID, "", reason) + if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, ev.PeerID, reason); err != nil { + log.WithContext(ctx).Errorf(err.Error()) + } delete(jm.pending, jobID) } } + +func (jm *JobManager) IsPeerConnected(peerID string) bool { + jm.mu.RLock() + defer jm.mu.RUnlock() + 
+ _, ok := jm.jobChannels[peerID] + return ok +} + +func (jm *JobManager) IsPeerHasPendingJobs(peerID string) bool { + jm.mu.RLock() + defer jm.mu.RUnlock() + + for _, ev := range jm.pending { + if ev.PeerID == peerID { + return true + } + } + return false +} diff --git a/management/server/peer.go b/management/server/peer.go index 876a51124..16abf2b40 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -353,22 +353,24 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p } // check if peer connected - // todo: implement jobManager.IsPeerConnected - // if !am.jobManager.IsPeerConnected(ctx, peerID) { - // return status.NewJobFailedError("peer not connected") - // } + if !am.jobManager.IsPeerConnected(peerID) { + return status.Errorf(status.BadRequest, "peer not connected") + } // check if already has pending jobs - // todo: implement jobManager.GetPendingJobsByPeerID - // if pending := am.jobManager.GetPendingJobsByPeerID(ctx, peerID); len(pending) > 0 { - // return status.NewJobAlreadyPendingError(peerID) - // } + if am.jobManager.IsPeerHasPendingJobs(peerID) { + return status.Errorf(status.BadRequest, "peer already has a pending job") + } + + jobStream, err := job.ToStreamJobRequest() + if err != nil { + return status.Errorf(status.BadRequest, "invalid job request %v", err) + } // try sending job first - // todo: implement am.jobManager.SendJob - // if err := am.jobManager.SendJob(ctx, peerID, job); err != nil { - // return status.NewJobFailedError(fmt.Sprintf("failed to send job: %v", err)) - // } + if err := am.jobManager.SendJob(ctx, accountID, peerID, jobStream); err != nil { + return status.Errorf(status.Internal, "failed to send job: %v", err) + } var peer *nbpeer.Peer var eventsToStore func() diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index de3a01e72..f27eddb2f 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@
-136,18 +136,35 @@ func (s *SqlStore) CreatePeerJob(ctx context.Context, job *types.Job) error { return nil } -// job was pending for too long and has been cancelled -// todo call it when we first start the jobChannel to make sure no stuck jobs -func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, peerID string) error { - now := time.Now().UTC() - return s.db. +func (s *SqlStore) CompletePeerJob(ctx context.Context, job *types.Job) error { + result := s.db. Model(&types.Job{}). - Where("peer_id = ? AND status = ?", types.JobStatusPending, peerID). - Updates(map[string]any{ - "status": types.JobStatusFailed, - "failed_reason": "Pending job cleanup: marked as failed automatically due to being stuck too long", - "completed_at": now, - }).Error + Where(idQueryCondition, job.ID). + Updates(job) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update job in store: %s", result.Error) + return status.Errorf(status.Internal, "failed to update job in store") + } + return nil +} + +// job was pending for too long and has been cancelled +func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error { + now := time.Now().UTC() + result := s.db. + Model(&types.Job{}). + Where(accountAndPeerIDQueryCondition+" AND status = ?", accountID, peerID, types.JobStatusPending).
+ Updates(types.Job{ + Status: types.JobStatusFailed, + FailedReason: reason, + CompletedAt: &now, + }) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error) + return status.Errorf(status.Internal, "failed to mark pending job as Failed in store") + } + return nil } // GetJobByID fetches job by ID @@ -159,7 +176,11 @@ func (s *SqlStore) GetPeerJobByID(ctx context.Context, accountID, jobID string) if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "job %s not found", jobID) } - return &job, err + if err != nil { + log.WithContext(ctx).Errorf("failed to fetch job from store: %s", err) + return nil, err + } + return &job, nil } // get all jobs @@ -171,34 +192,13 @@ func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([ Find(&jobs).Error if err != nil { + log.WithContext(ctx).Errorf("failed to fetch jobs from store: %s", err) return nil, err } return jobs, nil } -func (s *SqlStore) CompletePeerJob(accountID, jobID, result, failedReason string) error { - now := time.Now().UTC() - - updates := map[string]any{ - "completed_at": now, - } - - if result != "" && failedReason == "" { - updates["status"] = types.JobStatusSucceeded - updates["result"] = result - updates["failed_reason"] = "" - } else { - updates["status"] = types.JobStatusFailed - updates["failed_reason"] = failedReason - } - - return s.db. - Model(&types.Job{}). - Where(accountAndIDQueryCondition, accountID, jobID). 
- Updates(updates).Error -} - // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { log.WithContext(ctx).Tracef("acquiring global lock") diff --git a/management/server/store/store.go b/management/server/store/store.go index 98b9ae865..d8566e086 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -206,10 +206,10 @@ type Store interface { MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error CreatePeerJob(ctx context.Context, job *types.Job) error - CompletePeerJob(accountID, jobID, result, failedReason string) error + CompletePeerJob(ctx context.Context, job *types.Job) error GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error) GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error) - MarkPendingJobsAsFailed(ctx context.Context, peerID string) error + MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error } const ( diff --git a/management/server/types/job.go b/management/server/types/job.go index 484738790..d2973233c 100644 --- a/management/server/types/job.go +++ b/management/server/types/job.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -135,15 +136,15 @@ func validateAndBuildBundleParams(req api.WorkloadRequest, workload *Workload) e if err != nil { return fmt.Errorf("invalid parameters for bundle job") } - // validate bundle_for_time <= 5 minutes - if bundle.Parameters.BundleForTime < 0 || bundle.Parameters.BundleForTime > 5 { - return fmt.Errorf("bundle_for_time must be between 0 and 5, got %d", bundle.Parameters.BundleForTime) + // validate 
bundle_for_time <= 5 minutes if BundleFor is enabled + if bundle.Parameters.BundleFor && (bundle.Parameters.BundleForTime < 1 || bundle.Parameters.BundleForTime > 5) { + return fmt.Errorf("bundle_for_time must be between 1 and 5, got %d", bundle.Parameters.BundleForTime) } // validate log-file-count ≥ 1 and ≤ 1000 if bundle.Parameters.LogFileCount < 1 || bundle.Parameters.LogFileCount > 1000 { return fmt.Errorf("log-file-count must be between 1 and 1000, got %d", bundle.Parameters.LogFileCount) } - + workload.Parameters, err = json.Marshal(bundle.Parameters) if err != nil { return fmt.Errorf("failed to marshal workload parameters: %w", err) @@ -153,3 +154,65 @@ func validateAndBuildBundleParams(req api.WorkloadRequest, workload *Workload) e return nil } + +// ApplyResponse validates and maps a proto.JobResponse into the Job fields. +func (j *Job) ApplyResponse(resp *proto.JobResponse) error { + if resp == nil { + return nil + } + + j.ID = string(resp.ID) + now := time.Now().UTC() + j.CompletedAt = &now + switch resp.Status { + case proto.JobStatus_succeeded: + j.Status = JobStatusSucceeded + case proto.JobStatus_failed: + j.Status = JobStatusFailed + default: + j.Status = JobStatusPending + } + + if len(resp.Reason) > 0 { + j.FailedReason = string(resp.Reason) + } + + // Handle workload results (oneof) + var err error + switch r := resp.WorkloadResults.(type) { + case *proto.JobResponse_Bundle: + if j.Workload.Result, err = json.Marshal(r.Bundle); err != nil { + return fmt.Errorf("failed to marshal workload results: %w", err) + } + default: + return fmt.Errorf("unsupported workload response type: %T", r) + } + return nil +} + +func (j *Job) ToStreamJobRequest() (*proto.JobRequest, error) { + switch j.Workload.Type { + case JobTypeBundle: + return j.buildStreamBundleResponse() + default: + return nil, status.Errorf(status.InvalidArgument, "unknown job type: %v", j.Workload.Type) + } +} + +func (j *Job) buildStreamBundleResponse() (*proto.JobRequest, error) { + var p
api.BundleParameters + if err := json.Unmarshal(j.Workload.Parameters, &p); err != nil { + return nil, fmt.Errorf("invalid parameters for bundle job: %w", err) + } + return &proto.JobRequest{ + ID: []byte(j.ID), + WorkloadParameters: &proto.JobRequest_Bundle{ + Bundle: &proto.BundleParameters{ + BundleFor: p.BundleFor, + BundleForTime: int64(p.BundleForTime), + LogFileCount: int32(p.LogFileCount), + Anonymize: p.Anonymize, + }, + }, + }, nil +} diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index fe3910638..f5759ef21 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -180,13 +180,19 @@ func (c *GrpcClient) handleJobStream( // Main loop: receive, process, respond for { jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey) - if err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) { - log.WithContext(ctx).Info("job stream closed by server") + if err != nil && err != io.EOF { + c.notifyDisconnected(err) + s, _ := gstatus.FromError(err) + switch s.Code() { + case codes.PermissionDenied: + return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer + case codes.Canceled: + log.Debugf("management connection context has been canceled, this usually indicates shutdown") return nil + default: + log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err) + return err } - log.WithContext(ctx).Errorf("error receiving job request: %v", err) - return err } if jobReq == nil || len(jobReq.ID) == 0 { @@ -298,7 +304,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes. // blocking until error err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler) - if err != nil { + if err != nil && err != io.EOF { c.notifyDisconnected(err) s, _ := gstatus.FromError(err) switch s.Code() {