Compare commits

..

3 Commits

Author SHA1 Message Date
aliamerj
b7cd2ee252 just testing 2025-09-01 14:05:19 +03:00
Ali Amer
3f6d95552f implement remote debug api (#4418)
fix lint

clean up

fix MarkPendingJobsAsFailed

apply feedbacks 1

fix typo

change api and apply new schema

fix lint

fix api object

clean switch case

apply feedback 2

fix error handle in create job

get rid of any/interface type in job database

fix sonar issue

use RawJson for both parameters and results

running go mod tidy

update package

fix 1

update codegen

fix code-gen

fix snyk

fix snyk hopefully
2025-08-29 18:00:40 +03:00
Ali Amer
d4ac7f8df9 [management/client] create job channel between management and client (#4367)
* new bi-directional stream for jobs

* create bidirectional job channel to send requests from the server and receive responses from the client

* fix tests

* fix lint and close bug

* fix lint

* clean up & fix close of closed channel

* add nolint:staticcheck

* remove some redundant code from the job channel PR since this one is a cleaner rewrite

* cleanup removes a pending job safely

* change proto

* rename to jobRequest

* apply feedback 1

* apply feedback 2

* fix typo

* apply feedback 3

* apply last feedback
2025-08-28 16:49:09 +03:00
30 changed files with 2107 additions and 915 deletions

View File

@@ -86,6 +86,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
jobManager := mgmt.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, nil return nil, nil
@@ -106,13 +107,13 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil). Return(&types.Settings{}, nil).
AnyTimes() AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, nil, nil, &mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -22,6 +22,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
@@ -53,6 +55,7 @@ import (
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -62,7 +65,9 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto" sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/upload-server/types"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"google.golang.org/grpc/status"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -125,6 +130,8 @@ type EngineConfig struct {
BlockInbound bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
peerDaemonAddr string
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -461,6 +468,7 @@ func (e *Engine) Start() error {
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveJobEvents()
// starting network monitor at the very last to avoid disruptions // starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor() e.startNetworkMonitor()
@@ -886,6 +894,101 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil return nil
} }
func (e *Engine) getPeerClient(addr string) (*grpc.ClientConn, error) {
conn, err := grpc.NewClient(
strings.TrimPrefix(addr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
return conn, nil
}
func (e *Engine) receiveJobEvents() {
go func() {
err := e.mgmClient.Job(e.ctx, func(msg *mgmProto.JobRequest) *mgmProto.JobResponse {
// Simple test handler — replace with real logic
log.Infof("Received job request: %+v", msg)
// TODO: trigger local debug bundle or other job
return &mgmProto.JobResponse{
ID: msg.ID,
WorkloadResults: &mgmProto.JobResponse_Bundle{
Bundle: &mgmProto.BundleResult{
UploadKey: "upload-key",
},
},
}
})
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return
}
log.Debugf("stopped receiving jobs from Management Service")
}()
log.Debugf("connecting to Management Service jobs stream")
}
func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (string, error) {
conn, err := e.getPeerClient("unix:///var/run/netbird.sock")
if err != nil {
return "", err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
}()
statusOutput, err := e.getStatusOutput(params.Anonymize)
if err != nil {
return "", err
}
request := &cProto.DebugBundleRequest{
Anonymize: params.Anonymize,
SystemInfo: true,
Status: statusOutput,
LogFileCount: uint32(params.LogFileCount),
UploadURL: types.DefaultBundleURL,
}
service := cProto.NewDaemonServiceClient(conn)
resp, err := service.DebugBundle(e.clientCtx, request)
if err != nil {
return "", fmt.Errorf("failed to bundle debug: " + status.Convert(err).Message())
}
if resp.GetUploadFailureReason() != "" {
return "", fmt.Errorf("upload failed: " + resp.GetUploadFailureReason())
}
return resp.GetUploadedKey(), nil
}
func (e *Engine) getStatusOutput(anon bool) (string, error) {
conn, err := e.getPeerClient("unix:///var/run/netbird.sock")
if err != nil {
return "", err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf("Failed to close connection: %v", err)
}
}()
statusResp, err := cProto.NewDaemonServiceClient(conn).Status(e.clientCtx, &cProto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
if err != nil {
return "", fmt.Errorf("status failed: %v", status.Convert(err).Message())
}
return nbstatus.ParseToFullDetailSummary(
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""),
), nil
}
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it. // E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() { func (e *Engine) receiveManagementEvents() {

View File

@@ -1544,6 +1544,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
jobManager := server.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -1568,13 +1569,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -290,6 +290,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
jobManager := server.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -305,13 +306,13 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock := permissions.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{}) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, nil, nil, &server.MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
} }
gRPCAPIHandler := grpc.NewServer(gRPCOpts...) gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator()) srv, err := server.NewServer(context.Background(), s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator())
if err != nil { if err != nil {
log.Fatalf("failed to create management server: %v", err) log.Fatalf("failed to create management server: %v", err)
} }

View File

@@ -18,6 +18,12 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager {
}) })
} }
func (s *BaseServer) JobManager() *server.JobManager {
return Create(s, func() *server.JobManager {
return server.NewJobManager(s.Metrics(), s.Store())
})
}
func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator {
return Create(s, func() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator {
integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore()) integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.EventStore())

View File

@@ -60,7 +60,7 @@ func (s *BaseServer) PeersManager() peers.Manager {
func (s *BaseServer) AccountManager() account.Manager { func (s *BaseServer) AccountManager() account.Manager {
return Create(s, func() account.Manager { return Create(s, func() account.Manager {
accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, accountManager, err := server.BuildManager(context.Background(), s.Store(), s.PeersUpdateManager(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain,
s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy) s.dnsDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
if err != nil { if err != nil {
log.Fatalf("failed to create account manager: %v", err) log.Fatalf("failed to create account manager: %v", err)

View File

@@ -68,6 +68,7 @@ type DefaultAccountManager struct {
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{} cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
jobManager *JobManager
idpManager idp.Manager idpManager idp.Manager
cacheManager *nbcache.AccountUserDataCache cacheManager *nbcache.AccountUserDataCache
externalCacheManager nbcache.UserDataCache externalCacheManager nbcache.UserDataCache
@@ -174,6 +175,7 @@ func BuildManager(
ctx context.Context, ctx context.Context,
store store.Store, store store.Store,
peersUpdateManager *PeersUpdateManager, peersUpdateManager *PeersUpdateManager,
jobManager *JobManager,
idpManager idp.Manager, idpManager idp.Manager,
singleAccountModeDomain string, singleAccountModeDomain string,
dnsDomain string, dnsDomain string,
@@ -196,6 +198,7 @@ func BuildManager(
Store: store, Store: store,
geo: geo, geo: geo,
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
jobManager: jobManager,
idpManager: idpManager, idpManager: idpManager,
ctx: context.Background(), ctx: context.Background(),
cacheMux: sync.Mutex{}, cacheMux: sync.Mutex{},

View File

@@ -2891,7 +2891,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), NewJobManager(nil, store), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -219,7 +219,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
// return empty extra settings for expected calls to UpdateAccountPeers // return empty extra settings for expected calls to UpdateAccountPeers
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), NewJobManager(nil, store), nil, "", "netbird.test", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createDNSStore(t *testing.T) (store.Store, error) { func createDNSStore(t *testing.T) (store.Store, error) {

View File

@@ -2,7 +2,9 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"net/netip" "net/netip"
"strings" "strings"
@@ -45,6 +47,7 @@ type GRPCServer struct {
wgKey wgtypes.Key wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer proto.UnimplementedManagementServiceServer
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
jobManager *JobManager
config *nbconfig.Config config *nbconfig.Config
secretsManager SecretsManager secretsManager SecretsManager
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
@@ -61,6 +64,7 @@ func NewServer(
accountManager account.Manager, accountManager account.Manager,
settingsManager settings.Manager, settingsManager settings.Manager,
peersUpdateManager *PeersUpdateManager, peersUpdateManager *PeersUpdateManager,
jobManager *JobManager,
secretsManager SecretsManager, secretsManager SecretsManager,
appMetrics telemetry.AppMetrics, appMetrics telemetry.AppMetrics,
ephemeralManager *EphemeralManager, ephemeralManager *EphemeralManager,
@@ -74,10 +78,9 @@ func NewServer(
if appMetrics != nil { if appMetrics != nil {
// update gauge based on number of connected peers which is equal to open gRPC streams // update gauge based on number of connected peers which is equal to open gRPC streams
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 { if err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
return int64(len(peersUpdateManager.peerChannels)) return int64(len(peersUpdateManager.peerChannels))
}) }); err != nil {
if err != nil {
return nil, err return nil, err
} }
} }
@@ -86,6 +89,7 @@ func NewServer(
wgKey: key, wgKey: key,
// peerKey -> event channel // peerKey -> event channel
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
jobManager: jobManager,
accountManager: accountManager, accountManager: accountManager,
settingsManager: settingsManager, settingsManager: settingsManager,
config: config, config: config,
@@ -132,6 +136,46 @@ func getRealIP(ctx context.Context) net.IP {
return nil return nil
} }
func (s *GRPCServer) Job(srv proto.ManagementService_JobServer) error {
reqStart := time.Now()
ctx := srv.Context()
peerKey, err := s.handleHandshake(ctx, srv)
if err != nil {
return err
}
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
return err
}
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil {
return status.Errorf(codes.Unauthenticated, "peer is not registered")
}
// Start background response handler
s.startResponseReceiver(ctx, srv)
// Prepare per-peer state
updates, err := s.jobManager.CreateJobChannel(ctx, accountID, peer.ID)
if err != nil {
return status.Errorf(codes.Internal, err.Error())
}
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
// Main loop: forward jobs to client
return s.sendJobsLoop(ctx, accountID, peerKey, peer, updates, srv)
}
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account) // notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
@@ -147,7 +191,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
if err != nil { if err != nil {
return err return err
} }
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
@@ -171,7 +214,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
@@ -191,10 +233,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return err return err
} }
// Prepare per-peer state
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
s.ephemeralManager.OnPeerConnected(ctx, peer) s.ephemeralManager.OnPeerConnected(ctx, peer)
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
if s.appMetrics != nil { if s.appMetrics != nil {
@@ -209,6 +250,76 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
} }
func (s *GRPCServer) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) {
hello, err := srv.Recv()
if err != nil {
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "missing hello: %v", err)
}
jobReq := &proto.JobRequest{}
peerKey, err := s.parseRequest(ctx, hello, jobReq)
if err != nil {
return wgtypes.Key{}, err
}
return peerKey, nil
}
func (s *GRPCServer) startResponseReceiver(ctx context.Context, srv proto.ManagementService_JobServer) {
go func() {
for {
msg, err := srv.Recv()
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return
}
log.WithContext(ctx).Warnf("recv job response error: %v", err)
return
}
jobResp := &proto.JobResponse{}
if _, err := s.parseRequest(ctx, msg, jobResp); err != nil {
log.WithContext(ctx).Warnf("invalid job response: %v", err)
continue
}
if err := s.jobManager.HandleResponse(ctx, jobResp); err != nil {
log.WithContext(ctx).Errorf("handle job response failed: %v", err)
}
}
}()
}
func (s *GRPCServer) sendJobsLoop(
ctx context.Context,
accountID string,
peerKey wgtypes.Key,
peer *nbpeer.Peer,
updates <-chan *JobEvent,
srv proto.ManagementService_JobServer,
) error {
for {
select {
case job, open := <-updates:
if !open {
log.WithContext(ctx).Debugf("jobs channel for peer %s was closed", peerKey.String())
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
return nil
}
if err := s.sendJob(ctx, accountID, peerKey, peer, job, srv); err != nil {
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
log.WithContext(ctx).Warnf("send job failed: %v", err)
}
case <-ctx.Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
return ctx.Err()
}
}
}
// handleUpdates sends updates to the connected peer until the updates channel is closed. // handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String())
@@ -226,7 +337,6 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
return nil return nil
} }
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
return err return err
} }
@@ -249,7 +359,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
s.cancelPeerRoutines(ctx, accountID, peer) s.cancelPeerRoutines(ctx, accountID, peer)
return status.Errorf(codes.Internal, "failed processing update message") return status.Errorf(codes.Internal, "failed processing update message")
} }
err = srv.SendMsg(&proto.EncryptedMessage{ err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(), WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp, Body: encryptedResp,
}) })
@@ -261,6 +371,27 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
return nil return nil
} }
// sendJob encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendJob(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, job *JobEvent, srv proto.ManagementService_JobServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, job.Request)
if err != nil {
log.WithContext(ctx).Errorf("failed to encrypt job for peer %s: %v", peerKey.String(), err)
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
return nil
}
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
if err != nil {
s.jobManager.CloseChannel(ctx, accountID, peer.ID)
return status.Errorf(codes.Internal, "failed sending job message")
}
log.WithContext(ctx).Debugf("sent an job to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
unlock := s.acquirePeerLockByUID(ctx, peer.Key) unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock() defer unlock()
@@ -729,8 +860,8 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
var err error var err error
var turnToken *Token var turnToken *Token
if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials { if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials {
turnToken, err = s.secretsManager.GenerateTurnToken() turnToken, err = s.secretsManager.GenerateTurnToken()
if err != nil { if err != nil {

View File

@@ -119,6 +119,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
} }
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
jobManager := server.NewJobManager(nil, store)
updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId) updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId)
done := make(chan struct{}) done := make(chan struct{})
if validateUpdate { if validateUpdate {
@@ -138,7 +139,7 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve
userManager := users.NewManager(store) userManager := users.NewManager(store)
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) am, err := server.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
if err != nil { if err != nil {
t.Fatalf("Failed to create manager: %v", err) t.Fatalf("Failed to create manager: %v", err)
} }

View File

@@ -0,0 +1,223 @@
package server
import (
"context"
"fmt"
"sync"
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
log "github.com/sirupsen/logrus"
)
const jobChannelBuffer = 100
type JobEvent struct {
Job *types.Job
Request *proto.JobRequest
Response *proto.JobResponse
Done chan struct{} // closed when response arrives
StoreEvent func(meta map[string]any, peer *nbpeer.Peer)
}
type JobManager struct {
mu *sync.RWMutex
jobChannels map[string]chan *JobEvent // per-peer job streams
pending map[string]*JobEvent // jobID → event
responseWait time.Duration
metrics telemetry.AppMetrics
Store store.Store
}
func NewJobManager(metrics telemetry.AppMetrics, store store.Store) *JobManager {
return &JobManager{
jobChannels: make(map[string]chan *JobEvent),
pending: make(map[string]*JobEvent),
responseWait: 5 * time.Minute,
metrics: metrics,
mu: &sync.RWMutex{},
Store: store,
}
}
// CreateJobChannel creates or replaces a channel for a peer
func (jm *JobManager) CreateJobChannel(ctx context.Context, accountID, peerID string) (chan *JobEvent, error) {
// all pending jobs stored in db for this peer should be failed
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, "Pending job cleanup: marked as failed automatically due to being stuck too long"); err != nil {
return nil, err
}
jm.mu.Lock()
defer jm.mu.Unlock()
if ch, ok := jm.jobChannels[peerID]; ok {
close(ch)
delete(jm.jobChannels, peerID)
}
ch := make(chan *JobEvent, jobChannelBuffer)
jm.jobChannels[peerID] = ch
return ch, nil
}
// SendJob sends a job to a peer and tracks it as pending
func (jm *JobManager) SendJob(ctx context.Context, job *types.Job, storeEvent func(meta map[string]any, peer *nbpeer.Peer)) error {
jm.mu.RLock()
ch, ok := jm.jobChannels[job.PeerID]
jm.mu.RUnlock()
if !ok {
return fmt.Errorf("peer %s has no channel", job.PeerID)
}
req, err := job.ToStreamJobRequest()
if err != nil {
return err
}
event := &JobEvent{
Job: job,
Request: req,
Done: make(chan struct{}),
StoreEvent: storeEvent,
}
jm.mu.Lock()
jm.pending[string(req.ID)] = event
jm.mu.Unlock()
select {
case ch <- event:
case <-time.After(5 * time.Second):
jm.cleanup(ctx, string(req.ID), "timed out")
return fmt.Errorf("job channel full for peer %s", job.PeerID)
}
select {
case <-event.Done:
return nil
case <-time.After(jm.responseWait):
jm.cleanup(ctx, string(req.ID), "timed out")
return fmt.Errorf("job %s timed out", req.ID)
case <-ctx.Done():
jm.cleanup(ctx, string(req.ID), ctx.Err().Error())
return ctx.Err()
}
}
// HandleResponse marks a job as finished and moves it to completed
func (jm *JobManager) HandleResponse(ctx context.Context, resp *proto.JobResponse) error {
jm.mu.Lock()
defer jm.mu.Unlock()
event, ok := jm.pending[string(resp.ID)]
if !ok {
return fmt.Errorf("job %s not found", resp.ID)
}
fmt.Printf("we got this %+v\n", resp)
//update or create the store for job response
err := jm.saveJob(ctx, event.Job, resp, event.StoreEvent)
if err == nil {
event.Response = resp
}
close(event.Done)
delete(jm.pending, string(resp.ID))
return err
}
// CloseChannel closes a peers channel and cleans up its jobs
func (jm *JobManager) CloseChannel(ctx context.Context, accountID, peerID string) {
jm.mu.Lock()
defer jm.mu.Unlock()
if ch, ok := jm.jobChannels[peerID]; ok {
close(ch)
jm.jobChannels[peerID] = nil
delete(jm.jobChannels, peerID)
}
for jobID, ev := range jm.pending {
if ev.Job.PeerID == peerID {
// if the client disconnect and there is pending job then marke it as failed
if err := jm.Store.MarkPendingJobsAsFailed(ctx, accountID, peerID, "Time out"); err != nil {
log.WithContext(ctx).Errorf(err.Error())
}
delete(jm.pending, jobID)
}
}
}
// cleanup removes a pending job safely
func (jm *JobManager) cleanup(ctx context.Context, jobID string, reason string) {
jm.mu.Lock()
defer jm.mu.Unlock()
if ev, ok := jm.pending[jobID]; ok {
close(ev.Done)
if err := jm.Store.MarkPendingJobsAsFailed(ctx, ev.Job.AccountID, ev.Job.PeerID, reason); err != nil {
log.WithContext(ctx).Errorf(err.Error())
}
delete(jm.pending, jobID)
}
}
func (jm *JobManager) IsPeerConnected(peerID string) bool {
jm.mu.RLock()
defer jm.mu.RUnlock()
_, ok := jm.jobChannels[peerID]
return ok
}
func (jm *JobManager) IsPeerHasPendingJobs(peerID string) bool {
jm.mu.RLock()
defer jm.mu.RUnlock()
for _, ev := range jm.pending {
if ev.Job.PeerID == peerID {
return true
}
}
return false
}
func (jm *JobManager) saveJob(ctx context.Context, job *types.Job, response *proto.JobResponse, StoreEvent func(meta map[string]any, peer *nbpeer.Peer)) error {
var peer *nbpeer.Peer
var err error
var eventsToStore func()
// persist job in DB only if send succeeded
err = jm.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, job.AccountID, job.PeerID)
if err != nil {
return err
}
if err := transaction.CreateOrUpdatePeerJob(ctx, job, response); err != nil {
return err
}
jobMeta := map[string]any{
"job_id": job.ID,
"for_peer_id": job.PeerID,
"job_type": job.Workload.Type,
"job_status": job.Status,
"job_workload": job.Workload,
}
eventsToStore = func() {
StoreEvent(jobMeta, peer)
}
return nil
})
if err != nil {
return err
}
eventsToStore()
return nil
}

View File

@@ -427,6 +427,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
} }
peersUpdateManager := NewPeersUpdateManager(nil) peersUpdateManager := NewPeersUpdateManager(nil)
jobManager := NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck ctx := context.WithValue(context.Background(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
@@ -450,7 +451,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted", accountManager, err := BuildManager(ctx, store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil { if err != nil {
@@ -461,7 +462,7 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config
secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
ephemeralMgr := NewEphemeralManager(store, accountManager) ephemeralMgr := NewEphemeralManager(store, accountManager)
mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{}) mgmtServer, err := NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, ephemeralMgr, nil, MockIntegratedValidator{})
if err != nil { if err != nil {
return nil, nil, "", cleanup, err return nil, nil, "", cleanup, err
} }

View File

@@ -176,6 +176,7 @@ func startServer(
} }
peersUpdateManager := server.NewPeersUpdateManager(nil) peersUpdateManager := server.NewPeersUpdateManager(nil)
jobManager := server.NewJobManager(nil, str)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
@@ -202,6 +203,7 @@ func startServer(
context.Background(), context.Background(),
str, str,
peersUpdateManager, peersUpdateManager,
jobManager,
nil, nil,
"", "",
"netbird.selfhosted", "netbird.selfhosted",
@@ -226,6 +228,7 @@ func startServer(
accountManager, accountManager,
settingsMockManager, settingsMockManager,
peersUpdateManager, peersUpdateManager,
jobManager,
secretsManager, secretsManager,
nil, nil,
nil, nil,

View File

@@ -148,6 +148,7 @@ func (am *MockAccountManager) GetPeerJobByID(ctx context.Context, accountID, use
return nil, status.Errorf(codes.Unimplemented, "method CreateJob is not implemented") return nil, status.Errorf(codes.Unimplemented, "method CreateJob is not implemented")
} }
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
if am.SaveGroupFunc != nil { if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(ctx, accountID, userID, group, true) return am.SaveGroupFunc(ctx, accountID, userID, group, true)

View File

@@ -785,7 +785,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), NewJobManager(nil, store), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createNSStore(t *testing.T) (store.Store, error) { func createNSStore(t *testing.T) (store.Store, error) {

View File

@@ -353,53 +353,21 @@ func (am *DefaultAccountManager) CreatePeerJob(ctx context.Context, accountID, p
} }
// check if peer connected // check if peer connected
// todo: implement jobManager.IsPeerConnected if !am.jobManager.IsPeerConnected(peerID) {
// if !am.jobManager.IsPeerConnected(ctx, peerID) { return status.Errorf(status.BadRequest, "peer not connected")
// return status.NewJobFailedError("peer not connected") }
// }
// check if already has pending jobs // check if already has pending jobs
// todo: implement jobManager.GetPendingJobsByPeerID if am.jobManager.IsPeerHasPendingJobs(peerID) {
// if pending := am.jobManager.GetPendingJobsByPeerID(ctx, peerID); len(pending) > 0 { return status.Errorf(status.BadRequest, "peer already hase pending job")
// return status.NewJobAlreadyPendingError(peerID) }
// }
// try sending job first // try sending job first
// todo: implement am.jobManager.SendJob if err := am.jobManager.SendJob(ctx, job, func(meta map[string]any, peer *nbpeer.Peer) {
// if err := am.jobManager.SendJob(ctx, peerID, job); err != nil { am.StoreEvent(ctx, userID, peer.ID, accountID, activity.JobCreatedByUser, meta)
// return status.NewJobFailedError(fmt.Sprintf("failed to send job: %v", err)) }); err != nil {
// } return status.Errorf(status.Internal, "failed to send job: %v", err)
var peer *nbpeer.Peer
var eventsToStore func()
// persist job in DB only if send succeeded
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return err
}
if err := transaction.CreatePeerJob(ctx, job); err != nil {
return err
}
jobMeta := map[string]any{
"job_id": job.ID,
"for_peer_id": job.PeerID,
"job_type": job.Workload.Type,
"job_status": job.Status,
"job_workload": job.Workload,
}
eventsToStore = func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.JobCreatedByUser, jobMeta)
}
return nil
})
if err != nil {
return err
} }
eventsToStore()
return nil return nil
} }
@@ -422,7 +390,7 @@ func (am *DefaultAccountManager) GetAllPeerJobs(ctx context.Context, accountID,
return []*types.Job{}, nil return []*types.Job{}, nil
} }
accountJobs, err := am.Store.GetPeerJobs(ctx, accountID, peerID) accountJobs, err := am.Store.GetPeerJobs(accountID, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -449,7 +417,7 @@ func (am *DefaultAccountManager) GetPeerJobByID(ctx context.Context, accountID,
return &types.Job{}, nil return &types.Job{}, nil
} }
job, err := am.Store.GetPeerJobByID(ctx, accountID, jobID) job, err := am.Store.GetPeerJobByID(accountID, jobID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1274,7 +1274,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), NewJobManager(nil, s), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1354,7 +1354,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), NewJobManager(nil, s), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1502,7 +1502,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), NewJobManager(nil, s), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
@@ -1576,7 +1576,7 @@ func Test_LoginPeer(t *testing.T) {
AnyTimes() AnyTimes()
permissionsManager := permissions.NewManager(s) permissionsManager := permissions.NewManager(s)
am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), NewJobManager(nil, s), nil, "", "netbird.cloud", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
assert.NoError(t, err) assert.NoError(t, err)
existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"

View File

@@ -1284,7 +1284,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
permissionsManager := permissions.NewManager(store) permissionsManager := permissions.NewManager(store)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), NewJobManager(nil, store), nil, "", "netbird.selfhosted", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
} }
func createRouterStore(t *testing.T) (store.Store, error) { func createRouterStore(t *testing.T) (store.Store, error) {

View File

@@ -34,6 +34,7 @@ import (
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -127,8 +128,18 @@ func GetKeyQueryCondition(s *SqlStore) string {
} }
// SaveJob persists a job in DB // SaveJob persists a job in DB
func (s *SqlStore) CreatePeerJob(ctx context.Context, job *types.Job) error { func (s *SqlStore) CreateOrUpdatePeerJob(ctx context.Context, job *types.Job, jobResponse *proto.JobResponse) error {
result := s.db.Create(job) if err := job.ApplyResponse(jobResponse); err != nil {
return status.Errorf(status.Internal, err.Error())
}
fmt.Printf("new job or update %v\n", job)
result := s.db.
Model(&types.Job{}).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "id"}},
DoUpdates: clause.AssignmentColumns([]string{"completed_at", "status", "failed_reason", "workload_workload_type", "workload_parameters", "workload_result"}),
}).Create(job)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create job in store: %s", result.Error) log.WithContext(ctx).Errorf("failed to create job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to create job in store") return status.Errorf(status.Internal, "failed to create job in store")
@@ -137,21 +148,25 @@ func (s *SqlStore) CreatePeerJob(ctx context.Context, job *types.Job) error {
} }
// job was pending for too long and has been cancelled // job was pending for too long and has been cancelled
// todo call it when we first start the jobChannel to make sure no stuck jobs func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error {
func (s *SqlStore) MarkPendingJobsAsFailed(ctx context.Context, peerID string) error {
now := time.Now().UTC() now := time.Now().UTC()
return s.db. result := s.db.
Model(&types.Job{}). Model(&types.Job{}).
Where("peer_id = ? AND status = ?", types.JobStatusPending, peerID). Where(accountAndPeerIDQueryCondition+"AND status = ?", accountID, peerID, types.JobStatusPending).
Updates(map[string]any{ Updates(types.Job{
"status": types.JobStatusFailed, Status: types.JobStatusFailed,
"failed_reason": "Pending job cleanup: marked as failed automatically due to being stuck too long", FailedReason: reason,
"completed_at": now, CompletedAt: &now,
}).Error })
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark pending jobs as Failed job in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark pending job as Failed in store")
}
return nil
} }
// GetJobByID fetches job by ID // GetJobByID fetches job by ID
func (s *SqlStore) GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error) { func (s *SqlStore) GetPeerJobByID(accountID, jobID string) (*types.Job, error) {
var job types.Job var job types.Job
err := s.db. err := s.db.
Where(accountAndIDQueryCondition, accountID, jobID). Where(accountAndIDQueryCondition, accountID, jobID).
@@ -163,7 +178,7 @@ func (s *SqlStore) GetPeerJobByID(ctx context.Context, accountID, jobID string)
} }
// get all jobs // get all jobs
func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error) { func (s *SqlStore) GetPeerJobs(accountID, peerID string) ([]*types.Job, error) {
var jobs []*types.Job var jobs []*types.Job
err := s.db. err := s.db.
Where(accountAndPeerIDQueryCondition, accountID, peerID). Where(accountAndPeerIDQueryCondition, accountID, peerID).
@@ -177,28 +192,6 @@ func (s *SqlStore) GetPeerJobs(ctx context.Context, accountID, peerID string) ([
return jobs, nil return jobs, nil
} }
func (s *SqlStore) CompletePeerJob(accountID, jobID, result, failedReason string) error {
now := time.Now().UTC()
updates := map[string]any{
"completed_at": now,
}
if result != "" && failedReason == "" {
updates["status"] = types.JobStatusSucceeded
updates["result"] = result
updates["failed_reason"] = ""
} else {
updates["status"] = types.JobStatusFailed
updates["failed_reason"] = failedReason
}
return s.db.
Model(&types.Job{}).
Where(accountAndIDQueryCondition, accountID, jobID).
Updates(updates).Error
}
// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock
func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring global lock") log.WithContext(ctx).Tracef("acquiring global lock")

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/migration" "github.com/netbirdio/netbird/management/server/migration"
@@ -205,11 +206,10 @@ type Store interface {
IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error)
MarkAccountPrimary(ctx context.Context, accountID string) error MarkAccountPrimary(ctx context.Context, accountID string) error
UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error
CreatePeerJob(ctx context.Context, job *types.Job) error GetPeerJobByID(accountID, jobID string) (*types.Job, error)
CompletePeerJob(accountID, jobID, result, failedReason string) error GetPeerJobs(accountID, peerID string) ([]*types.Job, error)
GetPeerJobByID(ctx context.Context, accountID, jobID string) (*types.Job, error) MarkPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error
GetPeerJobs(ctx context.Context, accountID, peerID string) ([]*types.Job, error) CreateOrUpdatePeerJob(ctx context.Context, job *types.Job, jobResponse *proto.JobResponse) error
MarkPendingJobsAsFailed(ctx context.Context, peerID string) error
} }
const ( const (

View File

@@ -7,6 +7,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
@@ -143,7 +144,7 @@ func validateAndBuildBundleParams(req api.WorkloadRequest, workload *Workload) e
if bundle.Parameters.LogFileCount < 1 || bundle.Parameters.LogFileCount > 1000 { if bundle.Parameters.LogFileCount < 1 || bundle.Parameters.LogFileCount > 1000 {
return fmt.Errorf("log-file-count must be between 1 and 1000, got %d", bundle.Parameters.LogFileCount) return fmt.Errorf("log-file-count must be between 1 and 1000, got %d", bundle.Parameters.LogFileCount)
} }
workload.Parameters, err = json.Marshal(bundle.Parameters) workload.Parameters, err = json.Marshal(bundle.Parameters)
if err != nil { if err != nil {
return fmt.Errorf("failed to marshal workload parameters: %w", err) return fmt.Errorf("failed to marshal workload parameters: %w", err)
@@ -153,3 +154,63 @@ func validateAndBuildBundleParams(req api.WorkloadRequest, workload *Workload) e
return nil return nil
} }
// ApplyResponse validates and maps a proto.JobResponse into the Job fields.
func (j *Job) ApplyResponse(resp *proto.JobResponse) error {
if resp == nil {
return nil
}
now := time.Now().UTC()
j.CompletedAt = &now
switch resp.Status {
case proto.JobStatus_succeeded:
j.Status = JobStatusSucceeded
case proto.JobStatus_failed:
j.Status = JobStatusFailed
default:
j.Status = JobStatusPending
}
if len(resp.Reason) > 0 {
j.FailedReason = string(resp.Reason)
}
// Handle workload results (oneof)
var err error
switch r := resp.WorkloadResults.(type) {
case *proto.JobResponse_Bundle:
if j.Workload.Result, err = json.Marshal(r.Bundle); err != nil {
return fmt.Errorf("failed to marshal workload results: %w", err)
}
default:
return fmt.Errorf("unsupported workload response type: %T", r)
}
return nil
}
func (j *Job) ToStreamJobRequest() (*proto.JobRequest, error) {
switch j.Workload.Type {
case JobTypeBundle:
return j.buildStreamBundleResponse()
default:
return nil, status.Errorf(status.InvalidArgument, "unknown job type: %v", j.Workload.Type)
}
}
func (j *Job) buildStreamBundleResponse() (*proto.JobRequest, error) {
var p api.BundleParameters
if err := json.Unmarshal(j.Workload.Parameters, &p); err != nil {
return nil, fmt.Errorf("invalid parameters for bundle job: %w", err)
}
return &proto.JobRequest{
ID: []byte(j.ID),
WorkloadParameters: &proto.JobRequest_Bundle{
Bundle: &proto.BundleParameters{
BundleFor: p.BundleFor,
BundleForTime: int64(p.BundleForTime),
LogFileCount: int32(p.LogFileCount),
Anonymize: p.Anonymize,
},
},
}, nil
}

View File

@@ -14,6 +14,7 @@ import (
type Client interface { type Client interface {
io.Closer io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
GetServerPublicKey() (*wgtypes.Key, error) GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)

View File

@@ -71,6 +71,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Cleanup(cleanUp) t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
jobManager := mgmt.NewJobManager(nil, store)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
@@ -108,7 +109,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
Return(true, nil). Return(true, nil).
AnyTimes() AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, jobManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -116,7 +117,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
groupsManager := groups.NewManagerMock() groupsManager := groups.NewManagerMock()
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{}) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, jobManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -12,6 +12,7 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -111,6 +112,25 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function // Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error { func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
return c.withMgmtStream(ctx, func(ctx context.Context, serverPubKey wgtypes.Key) error {
return c.handleSyncStream(ctx, serverPubKey, sysInfo, msgHandler)
})
}
// Job wraps the real client's Job endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error {
return c.withMgmtStream(ctx, func(ctx context.Context, serverPubKey wgtypes.Key) error {
return c.handleJobStream(ctx, serverPubKey, msgHandler)
})
}
// withMgmtStream runs a streaming operation against the ManagementService
// It takes care of retries, connection readiness, and fetching server public key.
func (c *GrpcClient) withMgmtStream(
ctx context.Context,
handler func(ctx context.Context, serverPubKey wgtypes.Key) error,
) error {
operation := func() error { operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState()) log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState() connState := c.conn.GetState()
@@ -128,7 +148,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err return err
} }
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler) return handler(ctx, *serverPubKey)
} }
err := backoff.Retry(operation, defaultBackoff(ctx)) err := backoff.Retry(operation, defaultBackoff(ctx))
@@ -139,12 +159,132 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err return err
} }
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, func (c *GrpcClient) handleJobStream(
msgHandler func(msg *proto.SyncResponse) error) error { ctx context.Context,
serverPubKey wgtypes.Key,
msgHandler func(msg *proto.JobRequest) *proto.JobResponse,
) error {
stream, err := c.realClient.Job(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to open job stream: %v", err)
return err
}
// Handshake with the server
if err := c.sendHandshake(ctx, stream, serverPubKey); err != nil {
return err
}
log.WithContext(ctx).Debug("job stream handshake sent successfully")
// Main loop: receive, process, respond
for {
jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey)
if err != nil {
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
log.WithContext(ctx).Info("job stream closed by server")
return nil
}
log.WithContext(ctx).Errorf("error receiving job request: %v", err)
return err
}
if jobReq == nil || len(jobReq.ID) == 0 {
log.WithContext(ctx).Debug("received unknown or empty job request, skipping")
continue
}
jobResp := c.processJobRequest(ctx, jobReq, msgHandler)
if err := c.sendJobResponse(ctx, stream, serverPubKey, jobResp); err != nil {
return err
}
}
}
// sendHandshake sends the initial handshake message
func (c *GrpcClient) sendHandshake(ctx context.Context, stream proto.ManagementService_JobClient, serverPubKey wgtypes.Key) error {
handshakeReq := &proto.JobRequest{
ID: []byte(uuid.New().String()),
}
encHello, err := encryption.EncryptMessage(serverPubKey, c.key, handshakeReq)
if err != nil {
log.WithContext(ctx).Errorf("failed to encrypt handshake message: %v", err)
return err
}
return stream.Send(&proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encHello,
})
}
// receiveJobRequest waits for and decrypts a job request
func (c *GrpcClient) receiveJobRequest(
ctx context.Context,
stream proto.ManagementService_JobClient,
serverPubKey wgtypes.Key,
) (*proto.JobRequest, error) {
encryptedMsg, err := stream.Recv()
if err != nil {
return nil, err
}
jobReq := &proto.JobRequest{}
if err := encryption.DecryptMessage(serverPubKey, c.key, encryptedMsg.Body, jobReq); err != nil {
log.WithContext(ctx).Warnf("failed to decrypt job request: %v", err)
return nil, err
}
return jobReq, nil
}
// processJobRequest executes the handler and ensures a valid response
func (c *GrpcClient) processJobRequest(
ctx context.Context,
jobReq *proto.JobRequest,
msgHandler func(msg *proto.JobRequest) *proto.JobResponse,
) *proto.JobResponse {
jobResp := msgHandler(jobReq)
if jobResp == nil {
jobResp = &proto.JobResponse{
ID: jobReq.ID,
Status: proto.JobStatus_failed,
Reason: []byte("handler returned nil response"),
}
log.WithContext(ctx).Warnf("job handler returned nil for job %s", string(jobReq.ID))
}
return jobResp
}
// sendJobResponse encrypts and sends a job response
func (c *GrpcClient) sendJobResponse(
ctx context.Context,
stream proto.ManagementService_JobClient,
serverPubKey wgtypes.Key,
resp *proto.JobResponse,
) error {
encResp, err := encryption.EncryptMessage(serverPubKey, c.key, resp)
if err != nil {
log.WithContext(ctx).Errorf("failed to encrypt job response for job %s: %v", string(resp.ID), err)
return err
}
if err := stream.Send(&proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encResp,
}); err != nil {
log.WithContext(ctx).Errorf("failed to send job response for job %s: %v", string(resp.ID), err)
return err
}
log.WithContext(ctx).Debugf("job response sent successfully for job %s", string(resp.ID))
return nil
}
func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
ctx, cancelStream := context.WithCancel(ctx) ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream() defer cancelStream()
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo) stream, err := c.connectToSyncStream(ctx, serverPubKey, sysInfo)
if err != nil { if err != nil {
log.Debugf("failed to open Management Service stream: %s", err) log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied { if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -157,7 +297,7 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
c.notifyConnected() c.notifyConnected()
// blocking until error // blocking until error
err = c.receiveEvents(stream, serverPubKey, msgHandler) err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil { if err != nil {
c.notifyDisconnected(err) c.notifyDisconnected(err)
s, _ := gstatus.FromError(err) s, _ := gstatus.FromError(err)
@@ -186,7 +326,7 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
ctx, cancelStream := context.WithCancel(c.ctx) ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream() defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo) stream, err := c.connectToSyncStream(ctx, *serverPubKey, sysInfo)
if err != nil { if err != nil {
log.Debugf("failed to open Management Service stream: %s", err) log.Debugf("failed to open Management Service stream: %s", err)
return nil, err return nil, err
@@ -219,7 +359,7 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
return decryptedResp.GetNetworkMap(), nil return decryptedResp.GetNetworkMap(), nil
} }
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) { func (c *GrpcClient) connectToSyncStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)} req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
myPrivateKey := c.key myPrivateKey := c.key
@@ -238,7 +378,7 @@ func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.K
return sync, nil return sync, nil
} }
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error { func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
for { for {
update, err := stream.Recv() update, err := stream.Recv()
if err == io.EOF { if err == io.EOF {

View File

@@ -20,6 +20,7 @@ type MockClient struct {
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error LogoutFunc func() error
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
} }
func (m *MockClient) IsHealthy() bool { func (m *MockClient) IsHealthy() bool {
@@ -40,6 +41,13 @@ func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return m.SyncFunc(ctx, sysInfo, msgHandler) return m.SyncFunc(ctx, sysInfo, msgHandler)
} }
func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error {
if m.JobFunc == nil {
return nil
}
return m.JobFunc(ctx, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil { if m.GetServerPublicKeyFunc == nil {
return nil, nil return nil, nil

File diff suppressed because it is too large Load Diff

View File

@@ -48,6 +48,9 @@ service ManagementService {
// Logout logs out the peer and removes it from the management server // Logout logs out the peer and removes it from the management server
rpc Logout(EncryptedMessage) returns (Empty) {} rpc Logout(EncryptedMessage) returns (Empty) {}
// Executes a job on a target peer (e.g., debug bundle)
rpc Job(stream EncryptedMessage) returns (stream EncryptedMessage) {}
} }
message EncryptedMessage { message EncryptedMessage {
@@ -60,6 +63,42 @@ message EncryptedMessage {
int32 version = 3; int32 version = 3;
} }
message JobRequest {
bytes ID = 1;
oneof workload_parameters {
BundleParameters bundle = 10;
//OtherParameters other = 11;
}
}
enum JobStatus {
unknown_status = 0; //placeholder
succeeded = 1;
failed = 2;
}
message JobResponse{
bytes ID = 1;
JobStatus status=2;
bytes Reason=3;
oneof workload_results {
BundleResult bundle = 10;
//OtherResult other = 11;
}
}
message BundleParameters {
bool bundle_for = 1;
int64 bundle_for_time = 2;
int32 log_file_count = 3;
bool anonymize = 4;
}
message BundleResult {
string upload_key = 1;
}
message SyncRequest { message SyncRequest {
// Meta data of the peer // Meta data of the peer
PeerSystemMeta meta = 1; PeerSystemMeta meta = 1;

View File

@@ -50,6 +50,8 @@ type ManagementServiceClient interface {
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
// Logout logs out the peer and removes it from the management server // Logout logs out the peer and removes it from the management server
Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
// Executes a job on a target peer (e.g., debug bundle)
Job(ctx context.Context, opts ...grpc.CallOption) (ManagementService_JobClient, error)
} }
type managementServiceClient struct { type managementServiceClient struct {
@@ -155,6 +157,37 @@ func (c *managementServiceClient) Logout(ctx context.Context, in *EncryptedMessa
return out, nil return out, nil
} }
func (c *managementServiceClient) Job(ctx context.Context, opts ...grpc.CallOption) (ManagementService_JobClient, error) {
stream, err := c.cc.NewStream(ctx, &ManagementService_ServiceDesc.Streams[1], "/management.ManagementService/Job", opts...)
if err != nil {
return nil, err
}
x := &managementServiceJobClient{stream}
return x, nil
}
type ManagementService_JobClient interface {
Send(*EncryptedMessage) error
Recv() (*EncryptedMessage, error)
grpc.ClientStream
}
type managementServiceJobClient struct {
grpc.ClientStream
}
func (x *managementServiceJobClient) Send(m *EncryptedMessage) error {
return x.ClientStream.SendMsg(m)
}
func (x *managementServiceJobClient) Recv() (*EncryptedMessage, error) {
m := new(EncryptedMessage)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// ManagementServiceServer is the server API for ManagementService service. // ManagementServiceServer is the server API for ManagementService service.
// All implementations must embed UnimplementedManagementServiceServer // All implementations must embed UnimplementedManagementServiceServer
// for forward compatibility // for forward compatibility
@@ -191,6 +224,8 @@ type ManagementServiceServer interface {
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
// Logout logs out the peer and removes it from the management server // Logout logs out the peer and removes it from the management server
Logout(context.Context, *EncryptedMessage) (*Empty, error) Logout(context.Context, *EncryptedMessage) (*Empty, error)
// Executes a job on a target peer (e.g., debug bundle)
Job(ManagementService_JobServer) error
mustEmbedUnimplementedManagementServiceServer() mustEmbedUnimplementedManagementServiceServer()
} }
@@ -222,6 +257,9 @@ func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *Encrypted
func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMessage) (*Empty, error) { func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented") return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
} }
func (UnimplementedManagementServiceServer) Job(ManagementService_JobServer) error {
return status.Errorf(codes.Unimplemented, "method Job not implemented")
}
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {} func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -382,6 +420,32 @@ func _ManagementService_Logout_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _ManagementService_Job_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(ManagementServiceServer).Job(&managementServiceJobServer{stream})
}
type ManagementService_JobServer interface {
Send(*EncryptedMessage) error
Recv() (*EncryptedMessage, error)
grpc.ServerStream
}
type managementServiceJobServer struct {
grpc.ServerStream
}
func (x *managementServiceJobServer) Send(m *EncryptedMessage) error {
return x.ServerStream.SendMsg(m)
}
func (x *managementServiceJobServer) Recv() (*EncryptedMessage, error) {
m := new(EncryptedMessage)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service. // ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -424,6 +488,12 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
Handler: _ManagementService_Sync_Handler, Handler: _ManagementService_Sync_Handler,
ServerStreams: true, ServerStreams: true,
}, },
{
StreamName: "Job",
Handler: _ManagementService_Job_Handler,
ServerStreams: true,
ClientStreams: true,
},
}, },
Metadata: "management.proto", Metadata: "management.proto",
} }