Merge branch 'main' into refactor/permissions-manager

This commit is contained in:
pascal
2026-04-07 17:35:39 +02:00
112 changed files with 10718 additions and 2106 deletions

View File

@@ -4,8 +4,6 @@ import (
"context"
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -16,14 +14,18 @@ type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
GetServerURL() string
// IsHealthy returns the current connection status without blocking.
// Used by the engine to monitor connectivity in the background.
IsHealthy() bool
// HealthCheck actively probes the management server and returns an error if unreachable.
// Used to validate connectivity before committing configuration changes.
HealthCheck() error
SyncMeta(sysInfo *system.Info) error
Logout() error
CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error)

View File

@@ -189,7 +189,7 @@ func closeManagementSilently(s *grpc.Server, listener net.Listener) {
}
}
func TestClient_GetServerPublicKey(t *testing.T) {
func TestClient_HealthCheck(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
@@ -203,12 +203,8 @@ func TestClient_GetServerPublicKey(t *testing.T) {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error("couldn't retrieve management public key")
}
if key == nil {
t.Error("got an empty management public key")
if err := client.HealthCheck(); err != nil {
t.Errorf("health check failed: %v", err)
}
}
@@ -225,12 +221,8 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo, nil, nil)
_, err = client.Login(sysInfo, nil, nil)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
@@ -253,12 +245,8 @@ func TestClient_LoginRegistered(t *testing.T) {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
resp, err := client.Register(ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -282,13 +270,8 @@ func TestClient_Sync(t *testing.T) {
t.Fatal(err)
}
serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
_, err = client.Register(ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
@@ -304,7 +287,7 @@ func TestClient_Sync(t *testing.T) {
}
info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
_, err = remoteClient.Register(ValidKey, "", info, nil, nil)
if err != nil {
t.Fatal(err)
}
@@ -364,11 +347,6 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
t.Fatalf("error while creating testClient: %v", err)
}
key, err := testClient.GetServerPublicKey()
if err != nil {
t.Fatalf("error while getting server public key from testclient, %v", err)
}
var actualMeta *mgmtProto.PeerSystemMeta
var actualValidKey string
var wg sync.WaitGroup
@@ -405,7 +383,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
}
info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
_, err = testClient.Register(ValidKey, "", info, nil, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}
@@ -505,7 +483,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
}
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
if err != nil {
return nil, err
}
@@ -517,7 +495,7 @@ func Test_GetDeviceAuthorizationFlow(t *testing.T) {
}, nil
}
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
flowInfo, err := client.GetDeviceAuthorizationFlow()
if err != nil {
t.Error("error while retrieving device auth flow information")
}
@@ -551,7 +529,7 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
}
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
encryptedResp, err := encryption.EncryptMessage(client.key.PublicKey(), serverKey, expectedFlowInfo)
if err != nil {
return nil, err
}
@@ -563,11 +541,11 @@ func Test_GetPKCEAuthorizationFlow(t *testing.T) {
}, nil
}
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
flowInfo, err := client.GetPKCEAuthorizationFlow()
if err != nil {
t.Error("error while retrieving pkce auth flow information")
}
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match") //nolint:staticcheck
}

View File

@@ -202,7 +202,7 @@ func (c *GrpcClient) withMgmtStream(
return fmt.Errorf("connection to management is not ready and in %s state", connState)
}
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
@@ -404,7 +404,7 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
@@ -490,18 +490,24 @@ func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncCli
}
}
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
// HealthCheck actively probes the management server and returns an error if unreachable.
// Used to validate connectivity before committing configuration changes.
func (c *GrpcClient) HealthCheck() error {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
return errors.New(errMsgNoMgmtConnection)
}
_, err := c.getServerPublicKey()
return err
}
// getServerPublicKey fetches the server's WireGuard public key.
func (c *GrpcClient) getServerPublicKey() (*wgtypes.Key, error) {
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, fmt.Errorf("failed while getting Management Service public key")
return nil, fmt.Errorf("failed getting Management Service public key: %w", err)
}
serverKey, err := wgtypes.ParseKey(resp.Key)
@@ -512,7 +518,8 @@ func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
return &serverKey, nil
}
// IsHealthy probes the gRPC connection and returns false on errors
// IsHealthy returns the current connection status without blocking.
// Used by the engine to monitor connectivity in the background.
func (c *GrpcClient) IsHealthy() bool {
switch c.conn.GetState() {
case connectivity.TransientFailure:
@@ -538,12 +545,17 @@ func (c *GrpcClient) IsHealthy() bool {
return true
}
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
func (c *GrpcClient) login(req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
loginReq, err := encryption.EncryptMessage(*serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
@@ -577,7 +589,7 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
@@ -589,34 +601,40 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (c *GrpcClient) Register(setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
return c.login(&proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
}
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return nil, err
}
@@ -630,7 +648,7 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
}
flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
@@ -642,15 +660,21 @@ func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.D
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
func (c *GrpcClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
serverKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return nil, err
}
@@ -664,7 +688,7 @@ func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKC
}
flowInfoResp := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
err = encryption.DecryptMessage(*serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
@@ -681,7 +705,7 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return errors.New(errMsgNoMgmtConnection)
}
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
@@ -724,7 +748,7 @@ func (c *GrpcClient) notifyConnected() {
}
func (c *GrpcClient) Logout() error {
serverKey, err := c.GetServerPublicKey()
serverKey, err := c.getServerPublicKey()
if err != nil {
return fmt.Errorf("get server public key: %w", err)
}
@@ -751,7 +775,7 @@ func (c *GrpcClient) Logout() error {
// CreateExpose calls the management server to create a new expose service.
func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*ExposeResponse, error) {
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
return nil, err
}
@@ -787,7 +811,7 @@ func (c *GrpcClient) CreateExpose(ctx context.Context, req ExposeRequest) (*Expo
// RenewExpose extends the TTL of an active expose session on the management server.
func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
return err
}
@@ -810,7 +834,7 @@ func (c *GrpcClient) RenewExpose(ctx context.Context, domain string) error {
// StopExpose terminates an active expose session on the management server.
func (c *GrpcClient) StopExpose(ctx context.Context, domain string) error {
serverPubKey, err := c.GetServerPublicKey()
serverPubKey, err := c.getServerPublicKey()
if err != nil {
return err
}

View File

@@ -3,8 +3,6 @@ package client
import (
"context"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
@@ -14,12 +12,12 @@ import (
type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error)
GetServerURLFunc func() string
HealthCheckFunc func() error
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
@@ -53,39 +51,39 @@ func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequ
return m.JobFunc(ctx, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil
}
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (m *MockClient) Register(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil {
return nil, nil
}
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
return m.RegisterFunc(setupKey, jwtToken, info, sshKey, dnsLabels)
}
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.LoginFunc == nil {
return nil, nil
}
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
return m.LoginFunc(info, sshKey, dnsLabels)
}
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetDeviceAuthorizationFlowFunc(serverKey)
return m.GetDeviceAuthorizationFlowFunc()
}
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
func (m *MockClient) GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error) {
if m.GetPKCEAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetPKCEAuthorizationFlowFunc(serverKey)
return m.GetPKCEAuthorizationFlowFunc()
}
func (m *MockClient) HealthCheck() error {
if m.HealthCheckFunc == nil {
return nil
}
return m.HealthCheckFunc()
}
// GetNetworkMap mock implementation of GetNetworkMap from Client interface.

View File

@@ -0,0 +1,112 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// AzureIDPAPI APIs for Azure AD IDP integrations
type AzureIDPAPI struct {
c *Client
}
// List retrieves all Azure AD IDP integrations
func (a *AzureIDPAPI) List(ctx context.Context) ([]api.AzureIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.AzureIntegration](resp)
return ret, err
}
// Get retrieves a specific Azure AD IDP integration by ID
func (a *AzureIDPAPI) Get(ctx context.Context, integrationID string) (*api.AzureIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp/"+integrationID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.AzureIntegration](resp)
return &ret, err
}
// Create creates a new Azure AD IDP integration
func (a *AzureIDPAPI) Create(ctx context.Context, request api.CreateAzureIntegrationRequest) (*api.AzureIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/azure-idp", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.AzureIntegration](resp)
return &ret, err
}
// Update updates an existing Azure AD IDP integration
func (a *AzureIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateAzureIntegrationRequest) (*api.AzureIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/azure-idp/"+integrationID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.AzureIntegration](resp)
return &ret, err
}
// Delete deletes an Azure AD IDP integration
func (a *AzureIDPAPI) Delete(ctx context.Context, integrationID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/azure-idp/"+integrationID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// Sync triggers a manual sync for an Azure AD IDP integration
func (a *AzureIDPAPI) Sync(ctx context.Context, integrationID string) (*api.SyncResult, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/azure-idp/"+integrationID+"/sync", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SyncResult](resp)
return &ret, err
}
// GetLogs retrieves synchronization logs for an Azure AD IDP integration
func (a *AzureIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/azure-idp/"+integrationID+"/logs", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp)
return ret, err
}

View File

@@ -0,0 +1,252 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var testAzureIntegration = api.AzureIntegration{
Id: 1,
Enabled: true,
ClientId: "12345678-1234-1234-1234-123456789012",
TenantId: "87654321-4321-4321-4321-210987654321",
SyncInterval: 300,
GroupPrefixes: []string{"eng-"},
UserGroupPrefixes: []string{"dev-"},
Host: "microsoft.com",
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
func TestAzureIDP_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.AzureIntegration{testAzureIntegration})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testAzureIntegration, ret[0])
})
}
func TestAzureIDP_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestAzureIDP_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testAzureIntegration)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.Get(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, testAzureIntegration, *ret)
})
}
func TestAzureIDP_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.Get(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestAzureIDP_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateAzureIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "12345678-1234-1234-1234-123456789012", req.ClientId)
retBytes, _ := json.Marshal(testAzureIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.Create(context.Background(), api.CreateAzureIntegrationRequest{
ClientId: "12345678-1234-1234-1234-123456789012",
ClientSecret: "secret",
TenantId: "87654321-4321-4321-4321-210987654321",
Host: api.CreateAzureIntegrationRequestHostMicrosoftCom,
GroupPrefixes: &[]string{"eng-"},
})
require.NoError(t, err)
assert.Equal(t, testAzureIntegration, *ret)
})
}
func TestAzureIDP_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.Create(context.Background(), api.CreateAzureIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestAzureIDP_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.UpdateAzureIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, *req.Enabled)
retBytes, _ := json.Marshal(testAzureIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.Update(context.Background(), "int-1", api.UpdateAzureIntegrationRequest{
Enabled: ptr(true),
})
require.NoError(t, err)
assert.Equal(t, testAzureIntegration, *ret)
})
}
func TestAzureIDP_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.Update(context.Background(), "int-1", api.UpdateAzureIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestAzureIDP_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.AzureIDP.Delete(context.Background(), "int-1")
require.NoError(t, err)
})
}
func TestAzureIDP_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.AzureIDP.Delete(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestAzureIDP_Sync_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(api.SyncResult{Result: ptr("ok")})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.Sync(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, "ok", *ret.Result)
})
}
func TestAzureIDP_Sync_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.Sync(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestAzureIDP_GetLogs_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.GetLogs(context.Background(), "int-1")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testSyncLog, ret[0])
})
}
func TestAzureIDP_GetLogs_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/azure-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.AzureIDP.GetLogs(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}

View File

@@ -110,6 +110,15 @@ type Client struct {
// see more: https://docs.netbird.io/api/resources/scim
SCIM *SCIMAPI
// GoogleIDP NetBird Google Workspace IDP integration APIs
GoogleIDP *GoogleIDPAPI
// AzureIDP NetBird Azure AD IDP integration APIs
AzureIDP *AzureIDPAPI
// OktaScimIDP NetBird Okta SCIM IDP integration APIs
OktaScimIDP *OktaScimIDPAPI
// EventStreaming NetBird Event Streaming integration APIs
// see more: https://docs.netbird.io/api/resources/event-streaming
EventStreaming *EventStreamingAPI
@@ -185,6 +194,9 @@ func (c *Client) initialize() {
c.MSP = &MSPAPI{c}
c.EDR = &EDRAPI{c}
c.SCIM = &SCIMAPI{c}
c.GoogleIDP = &GoogleIDPAPI{c}
c.AzureIDP = &AzureIDPAPI{c}
c.OktaScimIDP = &OktaScimIDPAPI{c}
c.EventStreaming = &EventStreamingAPI{c}
c.IdentityProviders = &IdentityProvidersAPI{c}
c.Ingress = &IngressAPI{c}

View File

@@ -265,6 +265,65 @@ func (a *EDRAPI) DeleteHuntressIntegration(ctx context.Context) error {
return nil
}
// GetFleetDMIntegration retrieves the EDR FleetDM integration.
func (a *EDRAPI) GetFleetDMIntegration(ctx context.Context) (*api.EDRFleetDMResponse, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/edr/fleetdm", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFleetDMResponse](resp)
return &ret, err
}
// CreateFleetDMIntegration creates a new EDR FleetDM integration.
func (a *EDRAPI) CreateFleetDMIntegration(ctx context.Context, request api.EDRFleetDMRequest) (*api.EDRFleetDMResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/edr/fleetdm", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFleetDMResponse](resp)
return &ret, err
}
// UpdateFleetDMIntegration updates an existing EDR FleetDM integration.
func (a *EDRAPI) UpdateFleetDMIntegration(ctx context.Context, request api.EDRFleetDMRequest) (*api.EDRFleetDMResponse, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/edr/fleetdm", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.EDRFleetDMResponse](resp)
return &ret, err
}
// DeleteFleetDMIntegration deletes the EDR FleetDM integration.
func (a *EDRAPI) DeleteFleetDMIntegration(ctx context.Context) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/edr/fleetdm", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// BypassPeerCompliance bypasses compliance for a non-compliant peer
// See more: https://docs.netbird.io/api/resources/edr#bypass-peer-compliance
func (a *EDRAPI) BypassPeerCompliance(ctx context.Context, peerID string) (*api.BypassResponse, error) {

View File

@@ -0,0 +1,112 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// GoogleIDPAPI APIs for Google Workspace IDP integrations
type GoogleIDPAPI struct {
c *Client
}
// List retrieves all Google Workspace IDP integrations
func (a *GoogleIDPAPI) List(ctx context.Context) ([]api.GoogleIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.GoogleIntegration](resp)
return ret, err
}
// Get retrieves a specific Google Workspace IDP integration by ID
func (a *GoogleIDPAPI) Get(ctx context.Context, integrationID string) (*api.GoogleIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp/"+integrationID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.GoogleIntegration](resp)
return &ret, err
}
// Create creates a new Google Workspace IDP integration
func (a *GoogleIDPAPI) Create(ctx context.Context, request api.CreateGoogleIntegrationRequest) (*api.GoogleIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/google-idp", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.GoogleIntegration](resp)
return &ret, err
}
// Update updates an existing Google Workspace IDP integration
func (a *GoogleIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateGoogleIntegrationRequest) (*api.GoogleIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/google-idp/"+integrationID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.GoogleIntegration](resp)
return &ret, err
}
// Delete deletes a Google Workspace IDP integration
func (a *GoogleIDPAPI) Delete(ctx context.Context, integrationID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/google-idp/"+integrationID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// Sync triggers a manual sync for a Google Workspace IDP integration
func (a *GoogleIDPAPI) Sync(ctx context.Context, integrationID string) (*api.SyncResult, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/google-idp/"+integrationID+"/sync", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SyncResult](resp)
return &ret, err
}
// GetLogs retrieves synchronization logs for a Google Workspace IDP integration
func (a *GoogleIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/google-idp/"+integrationID+"/logs", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp)
return ret, err
}

View File

@@ -0,0 +1,248 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var testGoogleIntegration = api.GoogleIntegration{
Id: 1,
Enabled: true,
CustomerId: "C01234567",
SyncInterval: 300,
GroupPrefixes: []string{"eng-"},
UserGroupPrefixes: []string{"dev-"},
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
func TestGoogleIDP_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.GoogleIntegration{testGoogleIntegration})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testGoogleIntegration, ret[0])
})
}
func TestGoogleIDP_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGoogleIDP_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testGoogleIntegration)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.Get(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, testGoogleIntegration, *ret)
})
}
func TestGoogleIDP_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.Get(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestGoogleIDP_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateGoogleIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "C01234567", req.CustomerId)
retBytes, _ := json.Marshal(testGoogleIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.Create(context.Background(), api.CreateGoogleIntegrationRequest{
CustomerId: "C01234567",
ServiceAccountKey: "key-data",
GroupPrefixes: &[]string{"eng-"},
})
require.NoError(t, err)
assert.Equal(t, testGoogleIntegration, *ret)
})
}
func TestGoogleIDP_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.Create(context.Background(), api.CreateGoogleIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestGoogleIDP_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.UpdateGoogleIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, *req.Enabled)
retBytes, _ := json.Marshal(testGoogleIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.Update(context.Background(), "int-1", api.UpdateGoogleIntegrationRequest{
Enabled: ptr(true),
})
require.NoError(t, err)
assert.Equal(t, testGoogleIntegration, *ret)
})
}
func TestGoogleIDP_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.Update(context.Background(), "int-1", api.UpdateGoogleIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestGoogleIDP_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.GoogleIDP.Delete(context.Background(), "int-1")
require.NoError(t, err)
})
}
func TestGoogleIDP_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.GoogleIDP.Delete(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestGoogleIDP_Sync_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(api.SyncResult{Result: ptr("ok")})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.Sync(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, "ok", *ret.Result)
})
}
func TestGoogleIDP_Sync_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1/sync", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.Sync(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestGoogleIDP_GetLogs_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.GetLogs(context.Background(), "int-1")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testSyncLog, ret[0])
})
}
func TestGoogleIDP_GetLogs_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/google-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GoogleIDP.GetLogs(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}

View File

@@ -0,0 +1,112 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// OktaScimIDPAPI APIs for Okta SCIM IDP integrations
type OktaScimIDPAPI struct {
c *Client
}
// List retrieves all Okta SCIM IDP integrations
func (a *OktaScimIDPAPI) List(ctx context.Context) ([]api.OktaScimIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.OktaScimIntegration](resp)
return ret, err
}
// Get retrieves a specific Okta SCIM IDP integration by ID
func (a *OktaScimIDPAPI) Get(ctx context.Context, integrationID string) (*api.OktaScimIntegration, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp/"+integrationID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.OktaScimIntegration](resp)
return &ret, err
}
// Create creates a new Okta SCIM IDP integration
func (a *OktaScimIDPAPI) Create(ctx context.Context, request api.CreateOktaScimIntegrationRequest) (*api.OktaScimIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/okta-scim-idp", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.OktaScimIntegration](resp)
return &ret, err
}
// Update updates an existing Okta SCIM IDP integration
func (a *OktaScimIDPAPI) Update(ctx context.Context, integrationID string, request api.UpdateOktaScimIntegrationRequest) (*api.OktaScimIntegration, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/integrations/okta-scim-idp/"+integrationID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.OktaScimIntegration](resp)
return &ret, err
}
// Delete deletes an Okta SCIM IDP integration
func (a *OktaScimIDPAPI) Delete(ctx context.Context, integrationID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/integrations/okta-scim-idp/"+integrationID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// RegenerateToken regenerates the SCIM API token for an Okta SCIM integration
func (a *OktaScimIDPAPI) RegenerateToken(ctx context.Context, integrationID string) (*api.ScimTokenResponse, error) {
resp, err := a.c.NewRequest(ctx, "POST", "/api/integrations/okta-scim-idp/"+integrationID+"/token", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.ScimTokenResponse](resp)
return &ret, err
}
// GetLogs retrieves synchronization logs for an Okta SCIM IDP integration
func (a *OktaScimIDPAPI) GetLogs(ctx context.Context, integrationID string) ([]api.IdpIntegrationSyncLog, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/integrations/okta-scim-idp/"+integrationID+"/logs", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.IdpIntegrationSyncLog](resp)
return ret, err
}

View File

@@ -0,0 +1,246 @@
//go:build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var testOktaScimIntegration = api.OktaScimIntegration{
Id: 1,
AuthToken: "****",
Enabled: true,
GroupPrefixes: []string{"eng-"},
UserGroupPrefixes: []string{"dev-"},
LastSyncedAt: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
}
func TestOktaScimIDP_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.OktaScimIntegration{testOktaScimIntegration})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testOktaScimIntegration, ret[0])
})
}
func TestOktaScimIDP_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestOktaScimIDP_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testOktaScimIntegration)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.Get(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, testOktaScimIntegration, *ret)
})
}
func TestOktaScimIDP_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.Get(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestOktaScimIDP_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.CreateOktaScimIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "my-okta-connection", req.ConnectionName)
retBytes, _ := json.Marshal(testOktaScimIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.Create(context.Background(), api.CreateOktaScimIntegrationRequest{
ConnectionName: "my-okta-connection",
GroupPrefixes: &[]string{"eng-"},
})
require.NoError(t, err)
assert.Equal(t, testOktaScimIntegration, *ret)
})
}
func TestOktaScimIDP_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.Create(context.Background(), api.CreateOktaScimIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestOktaScimIDP_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.UpdateOktaScimIntegrationRequest
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, *req.Enabled)
retBytes, _ := json.Marshal(testOktaScimIntegration)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.Update(context.Background(), "int-1", api.UpdateOktaScimIntegrationRequest{
Enabled: ptr(true),
})
require.NoError(t, err)
assert.Equal(t, testOktaScimIntegration, *ret)
})
}
func TestOktaScimIDP_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.Update(context.Background(), "int-1", api.UpdateOktaScimIntegrationRequest{})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestOktaScimIDP_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.OktaScimIDP.Delete(context.Background(), "int-1")
require.NoError(t, err)
})
}
func TestOktaScimIDP_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.OktaScimIDP.Delete(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestOktaScimIDP_RegenerateToken_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
retBytes, _ := json.Marshal(testScimToken)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.RegenerateToken(context.Background(), "int-1")
require.NoError(t, err)
assert.Equal(t, testScimToken, *ret)
})
}
func TestOktaScimIDP_RegenerateToken_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/token", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.RegenerateToken(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Nil(t, ret)
})
}
func TestOktaScimIDP_GetLogs_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.IdpIntegrationSyncLog{testSyncLog})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.GetLogs(context.Background(), "int-1")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testSyncLog, ret[0])
})
}
func TestOktaScimIDP_GetLogs_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/integrations/okta-scim-idp/int-1/logs", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.OktaScimIDP.GetLogs(context.Background(), "int-1")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}