[management/client] Integrate Job API with Job Stream and Client Engine (#4428)

* integrate api

integrate api with stream and implement some client side

* fix typo and fix validation

* use real daemon address

* redo the connect via address

* Refactor the debug bundle generator to be ready to use from engine (#4469)

* fix tests

* fix lint

* fix bug with stream

* try refactor status 1

* fix convert fullStatus to statusOutput & add logFile

* fix tests

* fix tests

* fix not enough arguments in call to nbstatus.ConvertToStatusOutputOverview

* fix status_test

* fix(engine): avoid deadlock when stopping engine during debug bundle

* use atomic for lock-free

* use new lock

---------

Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
This commit is contained in:
Ali Amer
2025-09-23 14:07:48 +03:00
committed by GitHub
parent 3f6d95552f
commit acec87dd45
24 changed files with 562 additions and 365 deletions

View File

@@ -112,7 +112,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, "")
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
}
@@ -138,7 +138,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, "")
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener)
}

View File

@@ -308,7 +308,7 @@ func getStatusOutput(cmd *cobra.Command, anon bool) string {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
nbstatus.ConvertToStatusOutputOverview(statusResp.GetFullStatus(), anon, statusResp.GetDaemonVersion(), "", nil, nil, nil, "", ""),
)
}
return statusOutputString

View File

@@ -99,7 +99,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
profName = activeProf.Name
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
var statusOutputString string
switch {
case detailFlag:

View File

@@ -196,7 +196,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
//todo: do we need to pass logFile here ?
connectClient := internal.NewConnectClient(ctx, config, r, "")
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)

View File

@@ -131,7 +131,9 @@ func (c *Client) Start(startCtx context.Context) error {
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
//todo: do we need to pass logFile here ?
client := internal.NewConnectClient(ctx, c.config, recorder, "")
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available

View File

@@ -45,17 +45,19 @@ type ConnectClient struct {
engineMutex sync.Mutex
persistSyncResponse bool
LogFile string
}
func NewConnectClient(
ctx context.Context,
config *profilemanager.Config,
statusRecorder *peer.Status,
logFile string,
) *ConnectClient {
return &ConnectClient{
ctx: ctx,
config: config,
LogFile: logFile,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
@@ -261,7 +263,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, c.LogFile)
if err != nil {
log.Error(err)
return wrapErr(err)
@@ -270,7 +272,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, c.config)
c.engine.SetSyncResponsePersistence(c.persistSyncResponse)
c.engineMutex.Unlock()
@@ -415,7 +417,7 @@ func (c *ConnectClient) SetSyncResponsePersistence(enabled bool) {
}
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConfig *mgmProto.PeerConfig, logFile string) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
@@ -444,6 +446,9 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
BlockInbound: config.BlockInbound,
LazyConnectionEnabled: config.LazyConnectionEnabled,
LogFile: logFile,
ProfileConfig: config,
}
if config.PreSharedKey != "" {

View File

@@ -0,0 +1,101 @@
package debug
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
func UploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}

View File

@@ -1,4 +1,4 @@
package server
package debug
import (
"context"
@@ -38,7 +38,7 @@ func TestUpload(t *testing.T) {
fileContent := []byte("test file content")
err := os.WriteFile(file, fileContent, 0640)
require.NoError(t, err)
key, err := uploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
key, err := UploadDebugBundle(context.Background(), testURL+types.GetURLPath, testURL, file)
require.NoError(t, err)
id := getURLHash(testURL)
require.Contains(t, key, id+"/")

View File

@@ -32,6 +32,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
@@ -48,11 +49,13 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/jobexec"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
@@ -63,6 +66,7 @@ import (
signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -125,6 +129,11 @@ type EngineConfig struct {
BlockInbound bool
LazyConnectionEnabled bool
// for debug bundle generation
ProfileConfig *profilemanager.Config
LogFile string
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -189,11 +198,15 @@ type Engine struct {
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
// Sync response persistence
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
jobExecutor *jobexec.Executor
jobExecutorWG sync.WaitGroup
}
// Peer is an instance of the Connection Peer
@@ -207,17 +220,7 @@ type localIpUpdater interface {
}
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, c *profilemanager.Config) *Engine {
engine := &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
@@ -236,6 +239,7 @@ func NewEngine(
statusRecorder: statusRecorder,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
jobExecutor: jobexec.NewExecutor(),
}
sm := profilemanager.NewServiceManager("")
@@ -314,6 +318,8 @@ func (e *Engine) Stop() error {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
@@ -699,9 +705,18 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// Persist sync response under the dedicated lock (syncRespMux), not under syncMsgMux.
// Read the storage-enabled flag under the syncRespMux too.
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
e.syncRespMux.RUnlock()
// Store sync response if persistence is enabled
if e.persistSyncResponse {
if enabled {
e.syncRespMux.Lock()
e.latestSyncResponse = update
e.syncRespMux.Unlock()
log.Debugf("sync response persisted with serial %d", nm.GetSerial())
}
@@ -886,20 +901,27 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
func (e *Engine) receiveJobEvents() {
e.jobExecutorWG.Add(1)
go func() {
defer e.jobExecutorWG.Done()
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
// Simple test handler — replace with real logic
log.Infof("Received job request: %+v", msg)
// TODO: trigger local debug bundle or other job
return &mgmProto.JobResponse{
ID: msg.ID,
WorkloadResults: &mgmProto.JobResponse_Bundle{
Bundle: &mgmProto.BundleResult{
UploadKey: "upload-key",
},
},
resp := mgmProto.JobResponse{
ID: msg.ID,
Status: mgmProto.JobStatus_failed,
}
switch params := msg.WorkloadParameters.(type) {
case *mgmProto.JobRequest_Bundle:
bundleResult, err := e.handleBundle(params.Bundle)
if err != nil {
resp.Reason = []byte(err.Error())
return &resp
}
resp.Status = mgmProto.JobStatus_succeeded
resp.WorkloadResults = bundleResult
return &resp
default:
return nil
}
})
if err != nil {
@@ -914,6 +936,49 @@ func (e *Engine) receiveJobEvents() {
log.Debugf("connecting to Management Service jobs stream")
}
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobResponse_Bundle, error) {
syncResponse, err := e.GetLatestSyncResponse()
if err != nil {
return nil, fmt.Errorf("get latest sync response: %w", err)
}
if syncResponse == nil {
return nil, errors.New("sync response is not available")
}
// convert fullStatus to statusOutput
fullStatus := e.statusRecorder.GetFullStatus()
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, params.Anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", "")
statusOutput := nbstatus.ParseToFullDetailSummary(overview)
bundleDeps := debug.GeneratorDependencies{
InternalConfig: e.config.ProfileConfig,
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogFile: e.config.LogFile,
}
bundleJobParams := debug.BundleConfig{
Anonymize: params.Anonymize,
ClientStatus: statusOutput,
IncludeSystemInfo: true,
LogFileCount: uint32(params.LogFileCount),
}
uploadKey, err := e.jobExecutor.BundleJob(e.ctx, bundleDeps, bundleJobParams, e.config.ProfileConfig.ManagementURL.String())
if err != nil {
return nil, err
}
response := &mgmProto.JobResponse_Bundle{
Bundle: &mgmProto.BundleResult{
UploadKey: uploadKey,
},
}
return response, nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
@@ -1761,8 +1826,8 @@ func (e *Engine) stopDNSServer() {
// SetSyncResponsePersistence enables or disables sync response persistence
func (e *Engine) SetSyncResponsePersistence(enabled bool) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.syncRespMux.Lock()
defer e.syncRespMux.Unlock()
if enabled == e.persistSyncResponse {
return
@@ -1777,20 +1842,22 @@ func (e *Engine) SetSyncResponsePersistence(enabled bool) {
// GetLatestSyncResponse returns the stored sync response if persistence is enabled
func (e *Engine) GetLatestSyncResponse() (*mgmProto.SyncResponse, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.syncRespMux.RLock()
enabled := e.persistSyncResponse
latest := e.latestSyncResponse
e.syncRespMux.RUnlock()
if !e.persistSyncResponse {
if !enabled {
return nil, errors.New("sync response persistence is disabled")
}
if e.latestSyncResponse == nil {
if latest == nil {
//nolint:nilnil
return nil, nil
}
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(e.latestSyncResponse))
sr, ok := proto.Clone(e.latestSyncResponse).(*mgmProto.SyncResponse)
log.Debugf("Retrieving latest sync response with size %d bytes", proto.Size(latest))
sr, ok := proto.Clone(latest).(*mgmProto.SyncResponse)
if !ok {
return nil, fmt.Errorf("failed to clone sync response")
}

View File

@@ -27,6 +27,7 @@ import (
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
@@ -219,22 +220,13 @@ func TestEngine_SSH(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil,
)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -364,20 +356,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
engine := NewEngine(
ctx, cancel,
&signal.MockClient{},
&mgmt.MockClient{},
relayMgr,
&EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
nil)
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
wgIface := &MockWGIface{
NameFunc: func() string { return "utun102" },
@@ -595,7 +579,7 @@ func TestEngine_Sync(t *testing.T) {
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
@@ -759,7 +743,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet()
if err != nil {
@@ -960,7 +944,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
WgAddr: wgAddr,
WgPrivateKey: key,
WgPort: 33100,
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
engine.ctx = ctx
newNet, err := stdnet.NewNet()
@@ -1484,7 +1468,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
e.ctx = ctx
return e, err
}

View File

@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
// ConnectionListener export internal Listener for mobile
@@ -127,7 +127,8 @@ func (c *Client) Run(fd int32, interfaceName string) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
//todo: do we need to pass logFile here ?
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, "")
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
}

View File

@@ -0,0 +1,35 @@
package jobexec
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/upload-server/types"
)
type Executor struct {
}
func NewExecutor() *Executor {
return &Executor{}
}
func (e *Executor) BundleJob(ctx context.Context, debugBundleDependencies debug.GeneratorDependencies, params debug.BundleConfig, mgmURL string) (string, error) {
bundleGenerator := debug.NewBundleGenerator(debugBundleDependencies, params)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
key, err := debug.UploadDebugBundle(ctx, types.DefaultBundleURL, mgmURL, path)
if err != nil {
log.Errorf("failed to upload debug bundle to %v", err)
return "", fmt.Errorf("upload debug bundle: %w", err)
}
return key, nil
}

View File

@@ -4,24 +4,16 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/proto"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/upload-server/types"
)
const maxBundleUploadSize = 50 * 1024 * 1024
// DebugBundle creates a debug bundle and returns the location.
func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) {
s.mutex.Lock()
@@ -55,7 +47,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
if req.GetUploadURL() == "" {
return &proto.DebugBundleResponse{Path: path}, nil
}
key, err := uploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
key, err := debug.UploadDebugBundle(context.Background(), req.GetUploadURL(), s.config.ManagementURL.String(), path)
if err != nil {
log.Errorf("failed to upload debug bundle to %s: %v", req.GetUploadURL(), err)
return &proto.DebugBundleResponse{Path: path, UploadFailureReason: err.Error()}, nil
@@ -66,92 +58,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
return &proto.DebugBundleResponse{Path: path, UploadedKey: key}, nil
}
func uploadDebugBundle(ctx context.Context, url, managementURL, filePath string) (key string, err error) {
response, err := getUploadURL(ctx, url, managementURL)
if err != nil {
return "", err
}
err = upload(ctx, filePath, response)
if err != nil {
return "", err
}
return response.Key, nil
}
func upload(ctx context.Context, filePath string, response *types.GetURLResponse) error {
fileData, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("open file: %w", err)
}
defer fileData.Close()
stat, err := fileData.Stat()
if err != nil {
return fmt.Errorf("stat file: %w", err)
}
if stat.Size() > maxBundleUploadSize {
return fmt.Errorf("file size exceeds maximum limit of %d bytes", maxBundleUploadSize)
}
req, err := http.NewRequestWithContext(ctx, "PUT", response.URL, fileData)
if err != nil {
return fmt.Errorf("create PUT request: %w", err)
}
req.ContentLength = stat.Size()
req.Header.Set("Content-Type", "application/octet-stream")
putResp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("upload failed: %v", err)
}
defer putResp.Body.Close()
if putResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(putResp.Body)
return fmt.Errorf("upload status %d: %s", putResp.StatusCode, string(body))
}
return nil
}
func getUploadURL(ctx context.Context, url string, managementURL string) (*types.GetURLResponse, error) {
id := getURLHash(managementURL)
getReq, err := http.NewRequestWithContext(ctx, "GET", url+"?id="+id, nil)
if err != nil {
return nil, fmt.Errorf("create GET request: %w", err)
}
getReq.Header.Set(types.ClientHeader, types.ClientHeaderValue)
resp, err := http.DefaultClient.Do(getReq)
if err != nil {
return nil, fmt.Errorf("get presigned URL: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("get presigned URL status %d: %s", resp.StatusCode, string(body))
}
urlBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response body: %w", err)
}
var response types.GetURLResponse
if err := json.Unmarshal(urlBytes, &response); err != nil {
return nil, fmt.Errorf("unmarshal response: %w", err)
}
return &response, nil
}
func getURLHash(url string) string {
return fmt.Sprintf("%x", sha256.Sum256([]byte(url)))
}
// GetLogLevel gets the current logging level for the server.
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) {
s.mutex.Lock()

View File

@@ -13,15 +13,12 @@ import (
"time"
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -32,6 +29,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/version"
)
@@ -235,7 +233,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *profilemanage
runOperation := func() error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, s.logFile)
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
err := s.connectClient.Run(runningChan)
@@ -1026,7 +1024,7 @@ func (s *Server) Status(
}
fullStatus := s.statusRecorder.GetFullStatus()
pbFullStatus := toProtoFullStatus(fullStatus)
pbFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
pbFullStatus.Events = s.statusRecorder.GetEventHistory()
statusResponse.FullStatus = pbFullStatus
}
@@ -1131,93 +1129,6 @@ func (s *Server) onSessionExpire() {
}
}
func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
// sendTerminalNotification sends a terminal notification message
// to inform the user that the NetBird connection session has expired.
func sendTerminalNotification() error {

View File

@@ -11,6 +11,8 @@ import (
"strings"
"time"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
@@ -18,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/version"
"golang.org/x/exp/maps"
)
type PeerStateDetailOutput struct {
@@ -101,9 +104,7 @@ type OutputOverview struct {
ProfileName string `json:"profileName" yaml:"profileName"`
}
func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
pbFullStatus := resp.GetFullStatus()
func ConvertToStatusOutputOverview(pbFullStatus *proto.FullStatus, anon bool, daemonVersion string, statusFilter string, prefixNamesFilter []string, prefixNamesFilterMap map[string]struct{}, ipsFilter map[string]struct{}, connectionTypeFilter string, profName string) OutputOverview {
managementState := pbFullStatus.GetManagementState()
managementOverview := ManagementStateOutput{
URL: managementState.GetURL(),
@@ -119,12 +120,12 @@ func ConvertToStatusOutputOverview(resp *proto.StatusResponse, anon bool, status
}
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
peersOverview := mapPeers(pbFullStatus.GetPeers(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilter, connectionTypeFilter)
overview := OutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
DaemonVersion: daemonVersion,
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
@@ -458,6 +459,93 @@ func ParseToFullDetailSummary(overview OutputOverview) string {
)
}
func ToProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
pbFullStatus := proto.FullStatus{
ManagementState: &proto.ManagementState{},
SignalState: &proto.SignalState{},
LocalPeerState: &proto.LocalPeerState{},
Peers: []*proto.PeerState{},
}
pbFullStatus.ManagementState.URL = fullStatus.ManagementState.URL
pbFullStatus.ManagementState.Connected = fullStatus.ManagementState.Connected
if err := fullStatus.ManagementState.Error; err != nil {
pbFullStatus.ManagementState.Error = err.Error()
}
pbFullStatus.SignalState.URL = fullStatus.SignalState.URL
pbFullStatus.SignalState.Connected = fullStatus.SignalState.Connected
if err := fullStatus.SignalState.Error; err != nil {
pbFullStatus.SignalState.Error = err.Error()
}
pbFullStatus.LocalPeerState.IP = fullStatus.LocalPeerState.IP
pbFullStatus.LocalPeerState.PubKey = fullStatus.LocalPeerState.PubKey
pbFullStatus.LocalPeerState.KernelInterface = fullStatus.LocalPeerState.KernelInterface
pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN
pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive
pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled
pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes)
pbFullStatus.NumberOfForwardingRules = int32(fullStatus.NumOfForwardingRules)
pbFullStatus.LazyConnectionEnabled = fullStatus.LazyConnectionEnabled
for _, peerState := range fullStatus.Peers {
pbPeerState := &proto.PeerState{
IP: peerState.IP,
PubKey: peerState.PubKey,
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
BytesTx: peerState.BytesTx,
RosenpassEnabled: peerState.RosenpassEnabled,
Networks: maps.Keys(peerState.GetRoutes()),
Latency: durationpb.New(peerState.Latency),
}
pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState)
}
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()
}
pbFullStatus.Relays = append(pbFullStatus.Relays, pbRelayState)
}
for _, dnsState := range fullStatus.NSGroupStates {
var err string
if dnsState.Error != nil {
err = dnsState.Error.Error()
}
var servers []string
for _, server := range dnsState.Servers {
servers = append(servers, server.String())
}
pbDnsState := &proto.NSGroupState{
Servers: servers,
Domains: dnsState.Domains,
Enabled: dnsState.Enabled,
Error: err,
}
pbFullStatus.DnsServers = append(pbFullStatus.DnsServers, pbDnsState)
}
return &pbFullStatus
}
func parsePeers(peers PeersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var (
peersString = ""
@@ -737,3 +825,4 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *OutputOverview) {
}
}
}

View File

@@ -234,7 +234,7 @@ var overview = OutputOverview{
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := ConvertToStatusOutputOverview(resp, false, "", nil, nil, nil, "", "")
convertedResult := ConvertToStatusOutputOverview(resp.GetFullStatus(), false, resp.GetDaemonVersion(), "", nil, nil, nil, "", "")
assert.Equal(t, overview, convertedResult)
}

View File

@@ -433,7 +433,7 @@ func (s *serviceClient) collectDebugData(
var postUpStatusOutput string
if postUpStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus.GetFullStatus(), params.anonymize, postUpStatus.GetDaemonVersion(), "", nil, nil, nil, "", "")
postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
@@ -450,7 +450,7 @@ func (s *serviceClient) collectDebugData(
var preDownStatusOutput string
if preDownStatus != nil {
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus.GetFullStatus(), params.anonymize, preDownStatus.GetDaemonVersion(), "", nil, nil, nil, "", "")
preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview)
}
headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s",
@@ -581,7 +581,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa
var statusOutput string
if statusResp != nil {
overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "")
overview := nbstatus.ConvertToStatusOutputOverview(statusResp.GetFullStatus(), anonymize, statusResp.GetDaemonVersion(), "", nil, nil, nil, "", "")
statusOutput = nbstatus.ParseToFullDetailSummary(overview)
}

View File

@@ -163,11 +163,11 @@ func (s *GRPCServer) Job(srv proto.ManagementService_JobServer) error {
}
// Start background response handler
s.startResponseReceiver(ctx, accountID, srv)
s.startResponseReceiver(ctx, srv)
// Prepare per-peer state
updates := s.jobManager.CreateJobChannel(peer.ID)
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
updates := s.jobManager.CreateJobChannel(ctx, accountID, peer.ID)
log.WithContext(ctx).Debugf("Job: took %v", time.Since(reqStart))
// Main loop: forward jobs to client
return s.sendJobsLoop(ctx, accountID, peerKey, peer, updates, srv)
@@ -262,7 +262,7 @@ func (s *GRPCServer) handleHandshake(ctx context.Context, srv proto.ManagementSe
return peerKey, nil
}
func (s *GRPCServer) startResponseReceiver(ctx context.Context, accountID string, srv proto.ManagementService_JobServer) {
func (s *GRPCServer) startResponseReceiver(ctx context.Context, srv proto.ManagementService_JobServer) {
go func() {
for {
msg, err := srv.Recv()
@@ -280,7 +280,7 @@ func (s *GRPCServer) startResponseReceiver(ctx context.Context, accountID string
continue
}
if err := s.jobManager.HandleResponse(ctx, accountID, jobResp); err != nil {
if err := s.jobManager.HandleResponse(ctx, jobResp); err != nil {
log.WithContext(ctx).Errorf("handle job response failed: %v", err)
}

View File

@@ -8,7 +8,9 @@ import (
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
log "github.com/sirupsen/logrus"
)
const jobChannelBuffer = 100
@@ -17,7 +19,6 @@ type JobEvent struct {
PeerID string
Request *proto.JobRequest
Response *proto.JobResponse
Done chan struct{} // closed when response arrives
}
type JobManager struct {
@@ -42,9 +43,11 @@ func NewJobManager(metrics telemetry.AppMetrics, store store.Store) *JobManager
}
// CreateJobChannel creates or replaces a channel for a peer
func (jm *JobManager) CreateJobChannel(peerID string) chan *JobEvent {
// TODO: all pending jobs stored in db for this peer should be failed
// jm.Store.MarkPendingJobsAsFailed(peerID)
func (jm *JobManager) CreateJobChannel(ctx context.Context, accountID, peerID string) chan *JobEvent {
// all pending jobs stored in db for this peer should be failed
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, "Pending job cleanup: marked as failed automatically due to being stuck too long"); err != nil {
log.WithContext(ctx).Error(err.Error())
}
jm.mu.Lock()
defer jm.mu.Unlock()
@@ -71,7 +74,6 @@ func (jm *JobManager) SendJob(ctx context.Context, accountID, peerID string, req
event := &JobEvent{
PeerID: peerID,
Request: req,
Done: make(chan struct{}),
}
jm.mu.Lock()
@@ -80,14 +82,6 @@ func (jm *JobManager) SendJob(ctx context.Context, accountID, peerID string, req
select {
case ch <- event:
case <-time.After(5 * time.Second):
jm.cleanup(ctx, accountID, string(req.ID), "timed out")
return fmt.Errorf("job channel full for peer %s", peerID)
}
select {
case <-event.Done:
return nil
case <-time.After(jm.responseWait):
jm.cleanup(ctx, accountID, string(req.ID), "timed out")
return fmt.Errorf("job %s timed out", req.ID)
@@ -95,25 +89,32 @@ func (jm *JobManager) SendJob(ctx context.Context, accountID, peerID string, req
jm.cleanup(ctx, accountID, string(req.ID), ctx.Err().Error())
return ctx.Err()
}
return nil
}
// HandleResponse marks a job as finished and moves it to completed
func (jm *JobManager) HandleResponse(ctx context.Context, accountID string, resp *proto.JobResponse) error {
func (jm *JobManager) HandleResponse(ctx context.Context, resp *proto.JobResponse) error {
jm.mu.Lock()
defer jm.mu.Unlock()
event, ok := jm.pending[string(resp.ID)]
jobID := string(resp.ID)
event, ok := jm.pending[jobID]
if !ok {
return fmt.Errorf("job %s not found", resp.ID)
return fmt.Errorf("job %s not found", jobID)
}
var job types.Job
if err := job.ApplyResponse(resp); err != nil {
return fmt.Errorf("invalid job response: %v", err)
}
//update or create the store for job response
err := jm.Store.CompletePeerJob(ctx, &job)
if err == nil {
event.Response = resp
}
event.Response = resp
//TODO: update the store for job response
// jm.store.CompleteJob(ctx,accountID, string(resp.GetID()), string(resp.GetResult()),string(resp.GetReason()))
close(event.Done)
delete(jm.pending, string(resp.ID))
return nil
delete(jm.pending, jobID)
return err
}
// CloseChannel closes a peers channel and cleans up its jobs
@@ -130,7 +131,9 @@ func (jm *JobManager) CloseChannel(ctx context.Context, accountID, peerID string
for jobID, ev := range jm.pending {
if ev.PeerID == peerID {
// if the client disconnect and there is pending job then marke it as failed
// jm.store.CompleteJob(ctx,accountID, jobID,"", "Time out ")
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, "Time out peer disconnected"); err != nil {
log.WithContext(ctx).Errorf(err.Error())
}
delete(jm.pending, jobID)
}
}
@@ -142,8 +145,29 @@ func (jm *JobManager) cleanup(ctx context.Context, accountID, jobID string, reas
defer jm.mu.Unlock()
if ev, ok := jm.pending[jobID]; ok {
close(ev.Done)
// jm.store.CompleteJob(ctx, accountID, jobID, "", reason)
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, ev.PeerID, reason); err != nil {
log.WithContext(ctx).Errorf(err.Error())
}
delete(jm.pending, jobID)
}
}
func (jm *JobManager) IsPeerConnected(peerID string) bool {
jm.mu.RLock()
defer jm.mu.RUnlock()
_, ok := jm.jobChannels[peerID]
return ok
}
func (jm *JobManager) IsPeerHasPendingJobs(peerID string) bool {
jm.mu.RLock()
defer jm.mu.RUnlock()
for _, ev := range jm.pending {
if ev.PeerID == peerID {
return true
}
}
return false
}

View File

@@ -353,22 +353,24 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p
}
// check if peer connected
// todo: implement jobManager.IsPeerConnected
// if !am.jobManager.IsPeerConnected(ctx, peerID) {
// return status.NewJobFailedError("peer not connected")
// }
if !am.jobManager.IsPeerConnected(peerID) {
return status.Errorf(status.BadRequest, "peer not connected")
}
// check if already has pending jobs
// todo: implement jobManager.GetPendingJobsByPeerID
// if pending := am.jobManager.GetPendingJobsByPeerID(ctx, peerID); len(pending) > 0 {
// return status.NewJobAlreadyPendingError(peerID)
// }
if am.jobManager.IsPeerHasPendingJobs(peerID) {
return status.Errorf(status.BadRequest, "peer already hase pending job")
}
jobStream, err := job.ToStreamJobRequest()
if err != nil {
return status.Errorf(status.BadRequest, "invalid job request %v", err)
}
// try sending job first
// todo: implement am.jobManager.SendJob
// if err := am.jobManager.SendJob(ctx, peerID, job); err != nil {
// return status.NewJobFailedError(fmt.Sprintf("failed to send job: %v", err))
// }
if err := am.jobManager.SendJob(ctx, accountID, peerID, jobStream); err != nil {
return status.Errorf(status.Internal, "failed to send job: %v", err)
}
var peer *nbpeer.Peer
var eventsToStore func()

View File

@@ -136,18 +136,35 @@ func (s *SqlStore) CreatePeerJob(ctx context.Context, job *types.Job) error {
return nil
}
// job was pending for too long and has been cancelled
// todo call it when we first start the jobChannel to make sure no stuck jobs
func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, peerID string) error {
now := time.Now().UTC()
return s.db.
func (s *SqlStore) CompletePeerJob(ctx context.Context, job *types.Job) error {
result := s.db.
Model(&types.Job{}).
Where("peer_id = ? AND status = ?", types.JobStatusPending, peerID).
Updates(map[string]any{
"status": types.JobStatusFailed,
"failed_reason": "Pending job cleanup: marked as failed automatically due to being stuck too long",
"completed_at": now,
}).Error
Where(idQueryCondition, job.ID).
Updates(job)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to update job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to create job in store")
}
return nil
}
// job was pending for too long and has been cancelled
func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error {
now := time.Now().UTC()
result := s.db.
Model(&types.Job{}).
Where(accountAndPeerIDQueryCondition+"AND status = ?", accountID, peerID, types.JobStatusPending).
Updates(types.Job{
Status: types.JobStatusFailed,
FailedReason: reason,
CompletedAt: &now,
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
}
return nil
}
// GetJobByID fetches job by ID
@@ -159,7 +176,11 @@ func (s *SqlStore) GetPeerJobByID(ctx context.Context, accountID, jobID string)
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "job %s not found", jobID)
}
return &job, err
if err != nil {
log.WithContext(ctx).Errorf("failed to fetch job from store: %s", err)
return nil, err
}
return &job, nil
}
// get all jobs
@@ -171,34 +192,13 @@ func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([
Find(&jobs).Error
if err != nil {
log.WithContext(ctx).Errorf("failed to fetch jobs from store: %s", err)
return nil, err
}
return jobs, nil
}
func (s *SqlStore) CompletePeerJob(accountID, jobID, result, failedReason string) error {
now := time.Now().UTC()
updates := map[string]any{
"completed_at": now,
}
if result != "" && failedReason == "" {
updates["status"] = types.JobStatusSucceeded
updates["result"] = result
updates["failed_reason"] = ""
} else {
updates["status"] = types.JobStatusFailed
updates["failed_reason"] = failedReason
}
return s.db.
Model(&types.Job{}).
Where(accountAndIDQueryCondition, accountID, jobID).
Updates(updates).Error
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring global lock")

View File

@@ -206,10 +206,10 @@ type Store interface {
MarkAccountPrimary(ctx context.Context, accountID string) error
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
CreatePeerJob(ctx context.Context, job *types.Job) error
CompletePeerJob(accountID, jobID, result, failedReason string) error
CompletePeerJob(ctx context.Context, job *types.Job) error
GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error)
GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error)
MarkPendingJobsAsFailed(ctx context.Context, peerID string) error
MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
}
const (

View File

@@ -7,6 +7,7 @@ import (
"github.com/google/uuid"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -135,15 +136,15 @@ func validateAndBuildBundleParams(req api.WorkloadRequest, workload *Workload) e
if err != nil {
return fmt.Errorf("invalid parameters for bundle job")
}
// validate bundle_for_time <= 5 minutes
if bundle.Parameters.BundleForTime < 0 || bundle.Parameters.BundleForTime > 5 {
return fmt.Errorf("bundle_for_time must be between 0 and 5, got %d", bundle.Parameters.BundleForTime)
// validate bundle_for_time <= 5 minutes if BundleFor is enabled
if bundle.Parameters.BundleFor && bundle.Parameters.BundleForTime < 1 || bundle.Parameters.BundleForTime > 5 {
return fmt.Errorf("bundle_for_time must be between 1 and 5, got %d", bundle.Parameters.BundleForTime)
}
// validate log-file-count ≥ 1 and ≤ 1000
if bundle.Parameters.LogFileCount < 1 || bundle.Parameters.LogFileCount > 1000 {
return fmt.Errorf("log-file-count must be between 1 and 1000, got %d", bundle.Parameters.LogFileCount)
}
workload.Parameters, err = json.Marshal(bundle.Parameters)
if err != nil {
return fmt.Errorf("failed to marshal workload parameters: %w", err)
@@ -153,3 +154,65 @@ func validateAndBuildBundleParams(req api.WorkloadRequest, workload *Workload) e
return nil
}
// ApplyResponse validates and maps a proto.JobResponse into the Job fields.
func (j *Job) ApplyResponse(resp *proto.JobResponse) error {
if resp == nil {
return nil
}
j.ID = string(resp.ID)
now := time.Now().UTC()
j.CompletedAt = &now
switch resp.Status {
case proto.JobStatus_succeeded:
j.Status = JobStatusSucceeded
case proto.JobStatus_failed:
j.Status = JobStatusFailed
default:
j.Status = JobStatusPending
}
if len(resp.Reason) > 0 {
j.FailedReason = string(resp.Reason)
}
// Handle workload results (oneof)
var err error
switch r := resp.WorkloadResults.(type) {
case *proto.JobResponse_Bundle:
if j.Workload.Result, err = json.Marshal(r.Bundle); err != nil {
return fmt.Errorf("failed to marshal workload results: %w", err)
}
default:
return fmt.Errorf("unsupported workload response type: %T", r)
}
return nil
}
func (j *Job) ToStreamJobRequest() (*proto.JobRequest, error) {
switch j.Workload.Type {
case JobTypeBundle:
return j.buildStreamBundleResponse()
default:
return nil, status.Errorf(status.InvalidArgument, "unknown job type: %v", j.Workload.Type)
}
}
func (j *Job) buildStreamBundleResponse() (*proto.JobRequest, error) {
var p api.BundleParameters
if err := json.Unmarshal(j.Workload.Parameters, &p); err != nil {
return nil, fmt.Errorf("invalid parameters for bundle job: %w", err)
}
return &proto.JobRequest{
ID: []byte(j.ID),
WorkloadParameters: &proto.JobRequest_Bundle{
Bundle: &proto.BundleParameters{
BundleFor: p.BundleFor,
BundleForTime: int64(p.BundleForTime),
LogFileCount: int32(p.LogFileCount),
Anonymize: p.Anonymize,
},
},
}, nil
}

View File

@@ -180,13 +180,19 @@ func (c *GrpcClient) handleJobStream(
// Main loop: receive, process, respond
for {
jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
log.WithContext(ctx).Info("job stream closed by server")
if err != nil && err != io.EOF {
c.notifyDisconnected(err)
s, _ := gstatus.FromError(err)
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
log.WithContext(ctx).Errorf("error receiving job request: %v", err)
return err
}
if jobReq == nil || len(jobReq.ID) == 0 {
@@ -298,7 +304,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
// blocking until error
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil {
if err != nil && err != io.EOF {
c.notifyDisconnected(err)
s, _ := gstatus.FromError(err)
switch s.Code() {