[misc] Move shared components to shared directory (#4286)

Moved the following directories:

```
  - management/client → shared/management/client
  - management/domain → shared/management/domain
  - management/proto → shared/management/proto
  - signal/client → shared/signal/client
  - signal/proto → shared/signal/proto
  - relay/client → shared/relay/client
  - relay/auth → shared/relay/auth
```

and adjusted import paths
This commit is contained in:
Viktor Liu
2025-08-05 15:22:58 +02:00
committed by GitHub
parent 3d3c4c5844
commit 1d5e871bdf
172 changed files with 181 additions and 152 deletions

View File

@@ -0,0 +1,26 @@
package client
import (
"context"
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy() bool
SyncMeta(sysInfo *system.Info) error
Logout() error
}

View File

@@ -0,0 +1,554 @@
package client
import (
"context"
"net"
"os"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
)
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
func TestMain(m *testing.M) {
_ = util.InitLog("debug", util.LogConsole)
code := m.Run()
os.Exit(code)
}
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Helper()
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
config := &types.Config{}
_, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
}
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.
EXPECT().
GetSettings(
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.
EXPECT().
ValidateUserPermissions(
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
Return(true, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
return
}
}()
return s, lis
}
func startMockManagement(t *testing.T) (*grpc.Server, net.Listener, *mock_server.ManagementServiceServerMock, wgtypes.Key) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
serverKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
mgmtMockServer := &mock_server.ManagementServiceServerMock{
GetServerKeyFunc: func(context.Context, *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) {
response := &mgmtProto.ServerKeyResponse{
Key: serverKey.PublicKey().String(),
}
return response, nil
},
}
mgmtProto.RegisterManagementServiceServer(s, mgmtMockServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
return
}
}()
return s, lis, mgmtMockServer, serverKey
}
func closeManagementSilently(s *grpc.Server, listener net.Listener) {
s.GracefulStop()
err := listener.Close()
if err != nil {
log.Warnf("error while closing management listener %v", err)
return
}
}
func TestClient_GetServerPublicKey(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error("couldn't retrieve management public key")
}
if key == nil {
t.Error("got an empty management public key")
}
}
func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo, nil, nil)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied {
t.Errorf("expecting err code %d denied on unregistered login got %d", codes.PermissionDenied, s.Code())
}
}
func TestClient_LoginRegistered(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
if resp == nil {
t.Error("expecting non nil response, got nil")
}
}
func TestClient_Sync(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
// create and register second peer (we should receive on Sync request)
remoteKey, err := wgtypes.GenerateKey()
if err != nil {
t.Error(err)
}
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
if err != nil {
t.Fatal(err)
}
info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
if err != nil {
t.Fatal(err)
}
ch := make(chan *mgmtProto.SyncResponse, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
err = client.Sync(ctx, info, func(msg *mgmtProto.SyncResponse) error {
ch <- msg
return nil
})
if err != nil {
return
}
}()
select {
case resp := <-ch:
if resp.GetPeerConfig() == nil {
t.Error("expecting non nil PeerConfig got nil")
}
if resp.GetNetbirdConfig() == nil {
t.Error("expecting non nil NetbirdConfig got nil")
}
if len(resp.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
return
}
if resp.GetRemotePeersIsEmpty() == true {
t.Error("expecting RemotePeers property to be false, got true")
}
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
}
case <-time.After(3 * time.Second):
t.Error("timeout waiting for test to finish")
}
}
func Test_SystemMetaDataFromClient(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
testClient, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
key, err := testClient.GetServerPublicKey()
if err != nil {
t.Fatalf("error while getting server public key from testclient, %v", err)
}
var actualMeta *mgmtProto.PeerSystemMeta
var actualValidKey string
var wg sync.WaitGroup
wg.Add(1)
mgmtMockServer.LoginFunc = func(ctx context.Context, msg *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
}
loginReq := &mgmtProto.LoginRequest{}
err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq)
if err != nil {
log.Fatal(err)
}
actualMeta = loginReq.GetMeta()
actualValidKey = loginReq.GetSetupKey()
wg.Done()
loginResp := &mgmtProto.LoginResponse{}
encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}
wg.Wait()
protoNetAddr := make([]*mgmtProto.NetworkAddress, 0, len(info.NetworkAddresses))
for _, addr := range info.NetworkAddresses {
protoNetAddr = append(protoNetAddr, &mgmtProto.NetworkAddress{
NetIP: addr.NetIP.String(),
Mac: addr.Mac,
})
}
expectedMeta := &mgmtProto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
Kernel: info.Kernel,
Platform: info.Platform,
OS: info.OS,
Core: info.OSVersion,
OSVersion: info.OSVersion,
NetbirdVersion: info.NetbirdVersion,
KernelVersion: info.KernelVersion,
NetworkAddresses: protoNetAddr,
SysSerialNumber: info.SystemSerialNumber,
SysProductName: info.SystemProductName,
SysManufacturer: info.SystemManufacturer,
Environment: &mgmtProto.Environment{Cloud: info.Environment.Cloud, Platform: info.Environment.Platform},
}
assert.Equal(t, ValidKey, actualValidKey)
if !isEqual(expectedMeta, actualMeta) {
t.Errorf("expected and actual meta are not equal")
}
}
func isEqual(a, b *mgmtProto.PeerSystemMeta) bool {
if len(a.NetworkAddresses) != len(b.NetworkAddresses) {
return false
}
for _, addr := range a.GetNetworkAddresses() {
var found bool
for _, oAddr := range b.GetNetworkAddresses() {
if addr.GetMac() == oAddr.GetMac() && addr.GetNetIP() == oAddr.GetNetIP() {
found = true
continue
}
}
if !found {
return false
}
}
log.Infof("------")
return a.GetHostname() == b.GetHostname() &&
a.GetGoOS() == b.GetGoOS() &&
a.GetKernel() == b.GetKernel() &&
a.GetKernelVersion() == b.GetKernelVersion() &&
a.GetCore() == b.GetCore() &&
a.GetPlatform() == b.GetPlatform() &&
a.GetOS() == b.GetOS() &&
a.GetOSVersion() == b.GetOSVersion() &&
a.GetNetbirdVersion() == b.GetNetbirdVersion() &&
a.GetUiVersion() == b.GetUiVersion() &&
a.GetSysSerialNumber() == b.GetSysSerialNumber() &&
a.GetSysProductName() == b.GetSysProductName() &&
a.GetSysManufacturer() == b.GetSysManufacturer() &&
a.GetEnvironment().Cloud == b.GetEnvironment().Cloud &&
a.GetEnvironment().Platform == b.GetEnvironment().Platform
}
func Test_GetDeviceAuthorizationFlow(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{ClientID: "client"},
}
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving device auth flow information")
}
assert.Equal(t, expectedFlowInfo.Provider, flowInfo.Provider, "provider should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
}
func Test_GetPKCEAuthorizationFlow(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "client",
ClientSecret: "secret",
},
}
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving pkce auth flow information")
}
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
}

View File

@@ -0,0 +1,19 @@
package common
// LoginFlag introduces additional login flags to the PKCE authorization request
type LoginFlag uint8
const (
// LoginFlagPrompt adds prompt=login to the authorization request
LoginFlagPrompt LoginFlag = iota
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
LoginFlagMaxAge0
)
func (l LoginFlag) IsPromptLogin() bool {
return l == LoginFlagPrompt
}
func (l LoginFlag) IsMaxAge0Login() bool {
return l == LoginFlagMaxAge0
}

View File

@@ -0,0 +1,3 @@
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=

View File

@@ -0,0 +1,584 @@
package client
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
const ConnectTimeout = 10 * time.Second
const (
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
errMsgNoMgmtConnection = "no connection to management"
)
// ConnStateNotifier is a wrapper interface of the status recorders
type ConnStateNotifier interface {
MarkManagementDisconnected(error)
MarkManagementConnected()
}
type GrpcClient struct {
key wgtypes.Key
realClient proto.ManagementServiceClient
ctx context.Context
conn *grpc.ClientConn
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
}
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err
}
realClient := proto.NewManagementServiceClient(conn)
return &GrpcClient{
key: ourPrivateKey,
realClient: realClient,
ctx: ctx,
conn: conn,
connStateCallbackLock: sync.RWMutex{},
}, nil
}
// Close closes connection to the Management Service
func (c *GrpcClient) Close() error {
return c.conn.Close()
}
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
defer c.connStateCallbackLock.Unlock()
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// ready indicates whether the client is okay and ready to be used
// for now it just checks whether gRPC connection to the service is ready
func (c *GrpcClient) ready() bool {
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
}
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState()
if connState == connectivity.Shutdown {
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
c.conn.WaitForStateChange(ctx, connState)
return fmt.Errorf("connection to management is not ready and in %s state", connState)
}
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
}
err := backoff.Retry(operation, defaultBackoff(ctx))
if err != nil {
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
}
return err
}
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
msgHandler func(msg *proto.SyncResponse) error) error {
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
return err
}
log.Infof("connected to the Management Service stream")
c.notifyConnected()
// blocking until error
err = c.receiveEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
s, _ := gstatus.FromError(err)
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
}
return nil
}
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
}
defer func() {
_ = stream.CloseSend()
}()
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return nil, err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return nil, err
}
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return nil, err
}
if decryptedResp.GetNetworkMap() == nil {
return nil, fmt.Errorf("invalid msg, required network map")
}
return decryptedResp.GetNetworkMap(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
if err != nil {
log.Errorf("failed encrypting message: %s", err)
return nil, err
}
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
sync, err := c.realClient.Sync(ctx, syncReq)
if err != nil {
return nil, err
}
return sync, nil
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
for {
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return err
}
log.Debugf("got an update message from Management Service")
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return err
}
if err := msgHandler(decryptedResp); err != nil {
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
}
}
}
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, fmt.Errorf("failed while getting Management Service public key")
}
serverKey, err := wgtypes.ParseKey(resp.Key)
if err != nil {
return nil, err
}
return &serverKey, nil
}
// IsHealthy probes the gRPC connection and returns false on errors
func (c *GrpcClient) IsHealthy() bool {
switch c.conn.GetState() {
case connectivity.TransientFailure:
return false
case connectivity.Connecting:
return true
case connectivity.Shutdown:
return true
case connectivity.Idle:
case connectivity.Ready:
}
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
defer cancel()
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
if err != nil {
c.notifyDisconnected(err)
log.Warnf("health check returned: %s", err)
return false
}
c.notifyConnected()
return true
}
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
// retry only on context canceled
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
return backoff.Permanent(err)
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
}
return loginResp, nil
}
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG},
)
if err != nil {
return nil, err
}
flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
return flowInfoResp, nil
}
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
})
if err != nil {
return nil, err
}
flowInfoResp := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
return flowInfoResp, nil
}
// SyncMeta sends updated system metadata to the Management Service.
// It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
if !c.ready() {
return errors.New(errMsgNoMgmtConnection)
}
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
syncMetaReq, err := encryption.EncryptMessage(*serverPubKey, c.key, &proto.SyncMetaRequest{Meta: infoToMetaData(sysInfo)})
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: syncMetaReq,
})
return err
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementDisconnected(err)
}
func (c *GrpcClient) notifyConnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementConnected()
}
func (c *GrpcClient) Logout() error {
serverKey, err := c.GetServerPublicKey()
if err != nil {
return fmt.Errorf("get server public key: %w", err)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*5)
defer cancel()
message := &proto.Empty{}
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return fmt.Errorf("encrypt logout message: %w", err)
}
_, err = c.realClient.Logout(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
})
if err != nil {
return fmt.Errorf("logout: %w", err)
}
return nil
}
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
if info == nil {
return nil
}
addresses := make([]*proto.NetworkAddress, 0, len(info.NetworkAddresses))
for _, addr := range info.NetworkAddresses {
addresses = append(addresses, &proto.NetworkAddress{
NetIP: addr.NetIP.String(),
Mac: addr.Mac,
})
}
files := make([]*proto.File, 0, len(info.Files))
for _, file := range info.Files {
files = append(files, &proto.File{
Path: file.Path,
Exist: file.Exist,
ProcessIsRunning: file.ProcessIsRunning,
})
}
return &proto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
OS: info.OS,
Core: info.OSVersion,
OSVersion: info.OSVersion,
Platform: info.Platform,
Kernel: info.Kernel,
NetbirdVersion: info.NetbirdVersion,
UiVersion: info.UIVersion,
KernelVersion: info.KernelVersion,
NetworkAddresses: addresses,
SysSerialNumber: info.SystemSerialNumber,
SysManufacturer: info.SystemManufacturer,
SysProductName: info.SystemProductName,
Environment: &proto.Environment{
Cloud: info.Environment.Cloud,
Platform: info.Environment.Platform,
},
Files: files,
Flags: &proto.Flags{
RosenpassEnabled: info.RosenpassEnabled,
RosenpassPermissive: info.RosenpassPermissive,
ServerSSHAllowed: info.ServerSSHAllowed,
DisableClientRoutes: info.DisableClientRoutes,
DisableServerRoutes: info.DisableServerRoutes,
DisableDNS: info.DisableDNS,
DisableFirewall: info.DisableFirewall,
BlockLANAccess: info.BlockLANAccess,
BlockInbound: info.BlockInbound,
LazyConnectionEnabled: info.LazyConnectionEnabled,
},
}
}

View File

@@ -0,0 +1,95 @@
package client
import (
"context"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
}
func (m *MockClient) IsHealthy() bool {
return true
}
func (m *MockClient) Close() error {
if m.CloseFunc == nil {
return nil
}
return m.CloseFunc()
}
func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
if m.SyncFunc == nil {
return nil
}
return m.SyncFunc(ctx, sysInfo, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil
}
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil {
return nil, nil
}
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
}
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.LoginFunc == nil {
return nil, nil
}
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
}
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetDeviceAuthorizationFlowFunc(serverKey)
}
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
if m.GetPKCEAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetPKCEAuthorizationFlow(serverKey)
}
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
return nil, nil
}
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
if m.SyncMetaFunc == nil {
return nil
}
return m.SyncMetaFunc(sysInfo)
}
func (m *MockClient) Logout() error {
if m.LogoutFunc == nil {
return nil
}
return m.LogoutFunc()
}

View File

@@ -0,0 +1,60 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// AccountsAPI APIs for accounts, do not use directly
type AccountsAPI struct {
c *Client
}
// List list all accounts, only returns one account always
// See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Account](resp)
return ret, err
}
// Update update account settings
// See more: https://docs.netbird.io/api/resources/accounts#update-an-account
func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.PutApiAccountsAccountIdJSONRequestBody) (*api.Account, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Account](resp)
return &ret, err
}
// Delete delete account
// See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,183 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testAccount = api.Account{
Id: "Test",
Settings: api.AccountSettings{
Extra: &api.AccountExtraSettings{
PeerApprovalEnabled: false,
},
GroupsPropagationEnabled: ptr(true),
JwtGroupsEnabled: ptr(false),
PeerInactivityExpiration: 7,
PeerInactivityExpirationEnabled: true,
PeerLoginExpiration: 24,
PeerLoginExpirationEnabled: true,
RegularUsersViewBlocked: false,
RoutingPeerDnsResolutionEnabled: ptr(false),
},
}
)
func TestAccounts_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Account{testAccount})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Accounts.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testAccount, ret[0])
})
}
func TestAccounts_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Accounts.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestAccounts_List_ConnErr(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
ret, err := c.Accounts.List(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "404")
assert.Empty(t, ret)
})
}
func TestAccounts_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiAccountsAccountIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, *req.Settings.RoutingPeerDnsResolutionEnabled)
retBytes, _ := json.Marshal(testAccount)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Accounts.Update(context.Background(), "Test", api.PutApiAccountsAccountIdJSONRequestBody{
Settings: api.AccountSettings{
RoutingPeerDnsResolutionEnabled: ptr(true),
},
})
require.NoError(t, err)
assert.Equal(t, testAccount, *ret)
})
}
func TestAccounts_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Accounts.Update(context.Background(), "Test", api.PutApiAccountsAccountIdJSONRequestBody{
Settings: api.AccountSettings{
RoutingPeerDnsResolutionEnabled: ptr(true),
},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestAccounts_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Accounts.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestAccounts_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Accounts.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestAccounts_Integration_List(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
accounts, err := c.Accounts.List(context.Background())
require.NoError(t, err)
assert.Len(t, accounts, 1)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", accounts[0].Id)
assert.Equal(t, false, accounts[0].Settings.Extra.PeerApprovalEnabled)
})
}
func TestAccounts_Integration_Update(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
accounts, err := c.Accounts.List(context.Background())
require.NoError(t, err)
assert.Len(t, accounts, 1)
accounts[0].Settings.JwtAllowGroups = ptr([]string{"test"})
account, err := c.Accounts.Update(context.Background(), accounts[0].Id, api.AccountRequest{
Settings: accounts[0].Settings,
})
require.NoError(t, err)
assert.Equal(t, accounts[0].Id, account.Id)
assert.Equal(t, []string{"test"}, *account.Settings.JwtAllowGroups)
})
}
// Account deletion on MySQL and PostgreSQL databases causes unknown errors
// func TestAccounts_Integration_Delete(t *testing.T) {
// withBlackBoxServer(t, func(c *rest.Client) {
// accounts, err := c.Accounts.List(context.Background())
// require.NoError(t, err)
// assert.Len(t, accounts, 1)
// err = c.Accounts.Delete(context.Background(), accounts[0].Id)
// require.NoError(t, err)
// _, err = c.Accounts.List(context.Background())
// assert.Error(t, err)
// })
// }

View File

@@ -0,0 +1,172 @@
package rest
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/netbirdio/netbird/management/server/http/util"
)
// Client Management service HTTP REST API Client
type Client struct {
managementURL string
authHeader string
httpClient HttpClient
// Accounts NetBird account APIs
// see more: https://docs.netbird.io/api/resources/accounts
Accounts *AccountsAPI
// Users NetBird users APIs
// see more: https://docs.netbird.io/api/resources/users
Users *UsersAPI
// Tokens NetBird tokens APIs
// see more: https://docs.netbird.io/api/resources/tokens
Tokens *TokensAPI
// Peers NetBird peers APIs
// see more: https://docs.netbird.io/api/resources/peers
Peers *PeersAPI
// SetupKeys NetBird setup keys APIs
// see more: https://docs.netbird.io/api/resources/setup-keys
SetupKeys *SetupKeysAPI
// Groups NetBird groups APIs
// see more: https://docs.netbird.io/api/resources/groups
Groups *GroupsAPI
// Policies NetBird policies APIs
// see more: https://docs.netbird.io/api/resources/policies
Policies *PoliciesAPI
// PostureChecks NetBird posture checks APIs
// see more: https://docs.netbird.io/api/resources/posture-checks
PostureChecks *PostureChecksAPI
// Networks NetBird networks APIs
// see more: https://docs.netbird.io/api/resources/networks
Networks *NetworksAPI
// Routes NetBird routes APIs
// see more: https://docs.netbird.io/api/resources/routes
Routes *RoutesAPI
// DNS NetBird DNS APIs
// see more: https://docs.netbird.io/api/resources/routes
DNS *DNSAPI
// GeoLocation NetBird Geo Location APIs
// see more: https://docs.netbird.io/api/resources/geo-locations
GeoLocation *GeoLocationAPI
// Events NetBird Events APIs
// see more: https://docs.netbird.io/api/resources/events
Events *EventsAPI
}
// New initialize new Client instance using PAT token
func New(managementURL, token string) *Client {
return NewWithOptions(
WithManagementURL(managementURL),
WithPAT(token),
)
}
// NewWithBearerToken initialize new Client instance using Bearer token type
func NewWithBearerToken(managementURL, token string) *Client {
return NewWithOptions(
WithManagementURL(managementURL),
WithBearerToken(token),
)
}
// NewWithOptions initialize new Client instance with options
func NewWithOptions(opts ...option) *Client {
client := &Client{
httpClient: http.DefaultClient,
}
for _, option := range opts {
option(client)
}
client.initialize()
return client
}
func (c *Client) initialize() {
c.Accounts = &AccountsAPI{c}
c.Users = &UsersAPI{c}
c.Tokens = &TokensAPI{c}
c.Peers = &PeersAPI{c}
c.SetupKeys = &SetupKeysAPI{c}
c.Groups = &GroupsAPI{c}
c.Policies = &PoliciesAPI{c}
c.PostureChecks = &PostureChecksAPI{c}
c.Networks = &NetworksAPI{c}
c.Routes = &RoutesAPI{c}
c.DNS = &DNSAPI{c}
c.GeoLocation = &GeoLocationAPI{c}
c.Events = &EventsAPI{c}
}
// NewRequest creates and executes new management API request
func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader, query map[string]string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
if err != nil {
return nil, err
}
req.Header.Add("Authorization", c.authHeader)
req.Header.Add("Accept", "application/json")
if body != nil {
req.Header.Add("Content-Type", "application/json")
}
if len(query) != 0 {
q := req.URL.Query()
for k, v := range query {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode > 299 {
parsedErr, pErr := parseResponse[util.ErrorResponse](resp)
if pErr != nil {
return nil, pErr
}
return nil, errors.New(parsedErr.Message)
}
return resp, nil
}
func parseResponse[T any](resp *http.Response) (T, error) {
var ret T
if resp.Body == nil {
return ret, fmt.Errorf("Body missing, HTTP Error code %d", resp.StatusCode)
}
bs, err := io.ReadAll(resp.Body)
if err != nil {
return ret, err
}
err = json.Unmarshal(bs, &ret)
if err != nil {
return ret, fmt.Errorf("Error code %d, error unmarshalling body: %w", resp.StatusCode, err)
}
return ret, nil
}

View File

@@ -0,0 +1,34 @@
//go:build integration
// +build integration
package rest_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)
func withMockClient(callback func(*rest.Client, *http.ServeMux)) {
mux := &http.ServeMux{}
server := httptest.NewServer(mux)
defer server.Close()
c := rest.New(server.URL, "ABC")
callback(c, mux)
}
func ptr[T any, PT *T](x T) PT {
return &x
}
func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) {
t.Helper()
handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../server/testdata/store.sql", nil, false)
server := httptest.NewServer(handler)
defer server.Close()
c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp")
callback(c)
}

View File

@@ -0,0 +1,124 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// DNSAPI APIs for DNS Management, do not use directly
type DNSAPI struct {
c *Client
}
// ListNameserverGroups list all nameserver groups
// See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.NameserverGroup](resp)
return ret, err
}
// GetNameserverGroup get nameserver group info
// See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
// CreateNameserverGroup create new nameserver group
// See more: https://docs.netbird.io/api/resources/dns#create-a-nameserver-group
func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiDnsNameserversJSONRequestBody) (*api.NameserverGroup, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
// UpdateNameserverGroup update nameserver group info
// See more: https://docs.netbird.io/api/resources/dns#update-a-nameserver-group
func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID string, request api.PutApiDnsNameserversNsgroupIdJSONRequestBody) (*api.NameserverGroup, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
// DeleteNameserverGroup delete nameserver group
// See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// GetSettings get DNS settings
// See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSSettings](resp)
return &ret, err
}
// UpdateSettings update DNS settings
// See more: https://docs.netbird.io/api/resources/dns#update-dns-settings
func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettingsJSONRequestBody) (*api.DNSSettings, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSSettings](resp)
return &ret, err
}

View File

@@ -0,0 +1,301 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testNameserverGroup = api.NameserverGroup{
Id: "Test",
Name: "wow",
}
testSettings = api.DNSSettings{
DisabledManagementGroups: []string{"gone"},
}
)
func TestDNSNameserverGroup_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NameserverGroup{testNameserverGroup})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.ListNameserverGroups(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNameserverGroup, ret[0])
})
}
func TestDNSNameserverGroup_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.ListNameserverGroups(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSNameserverGroup_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNameserverGroup)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.GetNameserverGroup(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testNameserverGroup, *ret)
})
}
func TestDNSNameserverGroup_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.GetNameserverGroup(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSNameserverGroup_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiDnsNameserversJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNameserverGroup)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.CreateNameserverGroup(context.Background(), api.PostApiDnsNameserversJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNameserverGroup, *ret)
})
}
func TestDNSNameserverGroup_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.CreateNameserverGroup(context.Background(), api.PostApiDnsNameserversJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSNameserverGroup_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNameserverGroup)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.UpdateNameserverGroup(context.Background(), "Test", api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNameserverGroup, *ret)
})
}
func TestDNSNameserverGroup_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.UpdateNameserverGroup(context.Background(), "Test", api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSNameserverGroup_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.DNS.DeleteNameserverGroup(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestDNSNameserverGroup_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.DNS.DeleteNameserverGroup(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestDNSSettings_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testSettings)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.GetSettings(context.Background())
require.NoError(t, err)
assert.Equal(t, testSettings, *ret)
})
}
func TestDNSSettings_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.GetSettings(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSSettings_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiDnsSettingsJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, []string{"test"}, req.DisabledManagementGroups)
retBytes, _ := json.Marshal(testSettings)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.UpdateSettings(context.Background(), api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{"test"},
})
require.NoError(t, err)
assert.Equal(t, testSettings, *ret)
})
}
func TestDNSSettings_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.UpdateSettings(context.Background(), api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{"test"},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestDNS_Integration(t *testing.T) {
nsGroupReq := api.NameserverGroupRequest{
Description: "Test",
Enabled: true,
Domains: []string{},
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
Name: "test",
Nameservers: []api.Nameserver{
{
Ip: "8.8.8.8",
NsType: api.NameserverNsTypeUdp,
Port: 53,
},
},
Primary: true,
SearchDomainsEnabled: false,
}
withBlackBoxServer(t, func(c *rest.Client) {
// Create
nsGroup, err := c.DNS.CreateNameserverGroup(context.Background(), nsGroupReq)
require.NoError(t, err)
// List
nsGroups, err := c.DNS.ListNameserverGroups(context.Background())
require.NoError(t, err)
assert.Equal(t, *nsGroup, nsGroups[0])
// Update
nsGroupReq.Description = "TestUpdate"
nsGroup, err = c.DNS.UpdateNameserverGroup(context.Background(), nsGroup.Id, nsGroupReq)
require.NoError(t, err)
assert.Equal(t, "TestUpdate", nsGroup.Description)
// Delete
err = c.DNS.DeleteNameserverGroup(context.Background(), nsGroup.Id)
require.NoError(t, err)
// List again to ensure deletion
nsGroups, err = c.DNS.ListNameserverGroups(context.Background())
require.NoError(t, err)
assert.Len(t, nsGroups, 0)
})
}

View File

@@ -0,0 +1,26 @@
package rest
import (
"context"
"github.com/netbirdio/netbird/management/server/http/api"
)
// EventsAPI APIs for Events, do not use directly
type EventsAPI struct {
c *Client
}
// List list all events
// See more: https://docs.netbird.io/api/resources/events#list-all-events
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Event](resp)
return ret, err
}

View File

@@ -0,0 +1,70 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testEvent = api.Event{
Activity: "AccountCreate",
ActivityCode: api.EventActivityCodeAccountCreate,
}
)
func TestEvents_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Event{testEvent})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testEvent, ret[0])
})
}
func TestEvents_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestEvents_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
// Do something that would trigger any event
_, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
Ephemeral: ptr(true),
Name: "TestSetupKey",
Type: "reusable",
})
require.NoError(t, err)
events, err := c.Events.List(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, events)
})
}

View File

@@ -0,0 +1,40 @@
package rest
import (
"context"
"github.com/netbirdio/netbird/management/server/http/api"
)
// GeoLocationAPI APIs for Geo-Location, do not use directly
type GeoLocationAPI struct {
c *Client
}
// ListCountries list all country codes
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Country](resp)
return ret, err
}
// ListCountryCities Get a list of all English city names for a given country code
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.City](resp)
return ret, err
}

View File

@@ -0,0 +1,101 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testCountry = api.Country{
CountryCode: "DE",
CountryName: "Germany",
}
testCity = api.City{
CityName: "Berlin",
GeonameId: 2950158,
}
)
func TestGeo_ListCountries_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Country{testCountry})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GeoLocation.ListCountries(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testCountry, ret[0])
})
}
func TestGeo_ListCountries_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GeoLocation.ListCountries(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGeo_ListCountryCities_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.City{testCity})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GeoLocation.ListCountryCities(context.Background(), "Test")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testCity, ret[0])
})
}
func TestGeo_ListCountryCities_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GeoLocation.ListCountryCities(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGeo_Integration(t *testing.T) {
// Blackbox is initialized with empty GeoLocations
withBlackBoxServer(t, func(c *rest.Client) {
countries, err := c.GeoLocation.ListCountries(context.Background())
require.NoError(t, err)
assert.Empty(t, countries)
cities, err := c.GeoLocation.ListCountryCities(context.Background(), "DE")
require.NoError(t, err)
assert.Empty(t, cities)
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// GroupsAPI APIs for Groups, do not use directly
type GroupsAPI struct {
c *Client
}
// List list all groups
// See more: https://docs.netbird.io/api/resources/groups#list-all-groups
func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Group](resp)
return ret, err
}
// Get get group info
// See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group
func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
// Create create new group
// See more: https://docs.netbird.io/api/resources/groups#create-a-group
func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONRequestBody) (*api.Group, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
// Update update group info
// See more: https://docs.netbird.io/api/resources/groups#update-a-group
func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutApiGroupsGroupIdJSONRequestBody) (*api.Group, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
// Delete delete group
// See more: https://docs.netbird.io/api/resources/groups#delete-a-group
func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,215 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testGroup = api.Group{
Id: "Test",
Name: "wow",
PeersCount: 0,
}
)
func TestGroups_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Group{testGroup})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testGroup, ret[0])
})
}
func TestGroups_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGroups_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testGroup)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testGroup, *ret)
})
}
func TestGroups_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGroups_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiGroupsJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testGroup)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Create(context.Background(), api.PostApiGroupsJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testGroup, *ret)
})
}
func TestGroups_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Create(context.Background(), api.PostApiGroupsJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestGroups_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiGroupsGroupIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testGroup)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Update(context.Background(), "Test", api.PutApiGroupsGroupIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testGroup, *ret)
})
}
func TestGroups_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Update(context.Background(), "Test", api.PutApiGroupsGroupIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestGroups_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Groups.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestGroups_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Groups.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestGroups_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
groups, err := c.Groups.List(context.Background())
require.NoError(t, err)
assert.Len(t, groups, 1)
group, err := c.Groups.Create(context.Background(), api.GroupRequest{
Name: "Test",
})
require.NoError(t, err)
assert.Equal(t, "Test", group.Name)
assert.NotEmpty(t, group.Id)
group, err = c.Groups.Update(context.Background(), group.Id, api.GroupRequest{
Name: "Testnt",
})
require.NoError(t, err)
assert.Equal(t, "Testnt", group.Name)
err = c.Groups.Delete(context.Background(), group.Id)
require.NoError(t, err)
groups, err = c.Groups.List(context.Background())
require.NoError(t, err)
assert.Len(t, groups, 1)
})
}

View File

@@ -0,0 +1,48 @@
package rest
import (
"net/http"
"net/url"
)
// Impersonate returns a Client impersonated for a specific account
func (c *Client) Impersonate(account string) *Client {
client := NewWithOptions(
WithManagementURL(c.managementURL),
WithAuthHeader(c.authHeader),
WithHttpClient(newImpersonatedHttpClient(c, account)),
)
return client
}
type impersonatedHttpClient struct {
baseClient HttpClient
account string
}
func newImpersonatedHttpClient(c *Client, account string) *impersonatedHttpClient {
if hc, ok := c.httpClient.(*impersonatedHttpClient); ok {
hc.account = account
return hc
}
return &impersonatedHttpClient{
baseClient: c.httpClient,
account: account,
}
}
func (c *impersonatedHttpClient) Do(req *http.Request) (*http.Response, error) {
parsedURL, err := url.Parse(req.URL.String())
if err != nil {
return nil, err
}
query := parsedURL.Query()
query.Set("account", c.account)
parsedURL.RawQuery = query.Encode()
req.URL = parsedURL
return c.baseClient.Do(req)
}

View File

@@ -0,0 +1,77 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
)
var (
testImpersonatedAccount = api.Account{
Id: "ImpersonatedTest",
Settings: api.AccountSettings{
Extra: &api.AccountExtraSettings{
PeerApprovalEnabled: false,
},
GroupsPropagationEnabled: ptr(true),
JwtGroupsEnabled: ptr(false),
PeerInactivityExpiration: 7,
PeerInactivityExpirationEnabled: true,
PeerLoginExpiration: 24,
PeerLoginExpirationEnabled: true,
RegularUsersViewBlocked: false,
RoutingPeerDnsResolutionEnabled: ptr(false),
},
}
)
func TestImpersonation_Peers_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
impersonatedClient := c.Impersonate(testImpersonatedAccount.Id)
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, r.URL.Query().Get("account"), testImpersonatedAccount.Id)
retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := impersonatedClient.Peers.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPeer, ret[0])
})
}
func TestImpersonation_Change_Account(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
impersonatedClient := c.Impersonate(testImpersonatedAccount.Id)
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, r.URL.Query().Get("account"), testImpersonatedAccount.Id)
retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
_, err := impersonatedClient.Peers.List(context.Background())
require.NoError(t, err)
impersonatedClient = impersonatedClient.Impersonate("another-test-account")
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, r.URL.Query().Get("account"), "another-test-account")
retBytes, _ := json.Marshal(testPeer)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
_, err = impersonatedClient.Peers.Get(context.Background(), "Test")
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,276 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// NetworksAPI APIs for Networks, do not use directly
type NetworksAPI struct {
c *Client
}
// List list all networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-networks
func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Network](resp)
return ret, err
}
// Get get network info
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network
func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
// Create create new network
// See more: https://docs.netbird.io/api/resources/networks#create-a-network
func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSONRequestBody) (*api.Network, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
// Update update network
// See more: https://docs.netbird.io/api/resources/networks#update-a-network
func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.PutApiNetworksNetworkIdJSONRequestBody) (*api.Network, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
// Delete delete network
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network
func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// NetworkResourcesAPI APIs for Network Resources, do not use directly
type NetworkResourcesAPI struct {
c *Client
networkID string
}
// Resources APIs for network resources
func (a *NetworksAPI) Resources(networkID string) *NetworkResourcesAPI {
return &NetworkResourcesAPI{
c: a.c,
networkID: networkID,
}
}
// List list all resources in networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-resources
func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.NetworkResource](resp)
return ret, err
}
// Get get network resource info
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network-resource
func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) (*api.NetworkResource, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
// Create create new network resource
// See more: https://docs.netbird.io/api/resources/networks#create-a-network-resource
func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNetworksNetworkIdResourcesJSONRequestBody) (*api.NetworkResource, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
// Update update network resource
// See more: https://docs.netbird.io/api/resources/networks#update-a-network-resource
func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID string, request api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody) (*api.NetworkResource, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
// Delete delete network resource
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network-resource
func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// NetworkRoutersAPI APIs for Network Routers, do not use directly
type NetworkRoutersAPI struct {
c *Client
networkID string
}
// Routers APIs for network routers
func (a *NetworksAPI) Routers(networkID string) *NetworkRoutersAPI {
return &NetworkRoutersAPI{
c: a.c,
networkID: networkID,
}
}
// List list all routers in networks
// See more: https://docs.netbird.io/api/routers/networks#list-all-network-routers
func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.NetworkRouter](resp)
return ret, err
}
// Get get network router info
// See more: https://docs.netbird.io/api/routers/networks#retrieve-a-network-router
func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*api.NetworkRouter, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
// Create create new network router
// See more: https://docs.netbird.io/api/routers/networks#create-a-network-router
func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetworksNetworkIdRoutersJSONRequestBody) (*api.NetworkRouter, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
// Update update network router
// See more: https://docs.netbird.io/api/routers/networks#update-a-network-router
func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string, request api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody) (*api.NetworkRouter, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
// Delete delete network router
// See more: https://docs.netbird.io/api/routers/networks#delete-a-network-router
func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,594 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testNetwork = api.Network{
Id: "Test",
Name: "wow",
}
testNetworkResource = api.NetworkResource{
Description: ptr("meaw"),
Id: "awa",
}
testNetworkRouter = api.NetworkRouter{
Id: "ouch",
}
)
func TestNetworks_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Network{testNetwork})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNetwork, ret[0])
})
}
func TestNetworks_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworks_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetwork)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testNetwork, *ret)
})
}
func TestNetworks_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworks_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiNetworksJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNetwork)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Create(context.Background(), api.PostApiNetworksJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNetwork, *ret)
})
}
func TestNetworks_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Create(context.Background(), api.PostApiNetworksJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworks_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiNetworksNetworkIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNetwork)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Update(context.Background(), "Test", api.PutApiNetworksNetworkIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNetwork, *ret)
})
}
func TestNetworks_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Update(context.Background(), "Test", api.PutApiNetworksNetworkIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworks_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Networks.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestNetworks_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Networks.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestNetworks_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
network, err := c.Networks.Create(context.Background(), api.NetworkRequest{
Description: ptr("TestNetwork"),
Name: "Test",
})
assert.NoError(t, err)
assert.Equal(t, "Test", network.Name)
networks, err := c.Networks.List(context.Background())
assert.NoError(t, err)
assert.Empty(t, networks)
network, err = c.Networks.Update(context.Background(), "TestID", api.NetworkRequest{
Description: ptr("TestNetwork?"),
Name: "Test",
})
assert.NoError(t, err)
assert.Equal(t, "TestNetwork?", *network.Description)
err = c.Networks.Delete(context.Background(), "TestID")
assert.NoError(t, err)
})
}
func TestNetworkResources_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NetworkResource{testNetworkResource})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNetworkResource, ret[0])
})
}
func TestNetworkResources_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkResources_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetworkResource)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testNetworkResource, *ret)
})
}
func TestNetworkResources_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkResources_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiNetworksNetworkIdResourcesJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNetworkResource)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdResourcesJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNetworkResource, *ret)
})
}
func TestNetworkResources_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdResourcesJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworkResources_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNetworkResource)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNetworkResource, *ret)
})
}
func TestNetworkResources_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworkResources_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Networks.Resources("Meow").Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestNetworkResources_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Networks.Resources("Meow").Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestNetworkResources_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
_, err := c.Networks.Resources("TestNetwork").Create(context.Background(), api.NetworkResourceRequest{
Address: "test.com",
Description: ptr("Description"),
Enabled: false,
Groups: []string{"test"},
Name: "test",
})
assert.NoError(t, err)
_, err = c.Networks.Resources("TestNetwork").List(context.Background())
assert.NoError(t, err)
_, err = c.Networks.Resources("TestNetwork").Get(context.Background(), "TestResource")
assert.NoError(t, err)
_, err = c.Networks.Resources("TestNetwork").Update(context.Background(), "TestResource", api.NetworkResourceRequest{
Address: "testnt.com",
})
assert.NoError(t, err)
err = c.Networks.Resources("TestNetwork").Delete(context.Background(), "TestResource")
assert.NoError(t, err)
})
}
func TestNetworkRouters_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNetworkRouter, ret[0])
})
}
func TestNetworkRouters_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkRouters_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetworkRouter)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testNetworkRouter, *ret)
})
}
func TestNetworkRouters_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkRouters_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiNetworksNetworkIdRoutersJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "test", *req.Peer)
retBytes, _ := json.Marshal(testNetworkRouter)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdRoutersJSONRequestBody{
Peer: ptr("test"),
})
require.NoError(t, err)
assert.Equal(t, testNetworkRouter, *ret)
})
}
func TestNetworkRouters_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdRoutersJSONRequestBody{
Peer: ptr("test"),
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworkRouters_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "test", *req.Peer)
retBytes, _ := json.Marshal(testNetworkRouter)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody{
Peer: ptr("test"),
})
require.NoError(t, err)
assert.Equal(t, testNetworkRouter, *ret)
})
}
func TestNetworkRouters_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody{
Peer: ptr("test"),
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworkRouters_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Networks.Routers("Meow").Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestNetworkRouters_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Networks.Routers("Meow").Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestNetworkRouters_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
_, err := c.Networks.Routers("TestNetwork").Create(context.Background(), api.NetworkRouterRequest{
Enabled: false,
Masquerade: false,
Metric: 9999,
PeerGroups: ptr([]string{"test"}),
})
assert.NoError(t, err)
_, err = c.Networks.Routers("TestNetwork").List(context.Background())
assert.NoError(t, err)
_, err = c.Networks.Routers("TestNetwork").Get(context.Background(), "TestRouter")
assert.NoError(t, err)
_, err = c.Networks.Routers("TestNetwork").Update(context.Background(), "TestRouter", api.NetworkRouterRequest{
Enabled: true,
})
assert.NoError(t, err)
err = c.Networks.Routers("TestNetwork").Delete(context.Background(), "TestRouter")
assert.NoError(t, err)
})
}

View File

@@ -0,0 +1,44 @@
package rest
import "net/http"
// option modifier for creation of Client
type option func(*Client)
// HTTPClient interface for HTTP client
type HttpClient interface {
Do(req *http.Request) (*http.Response, error)
}
// WithHTTPClient overrides HTTPClient used
func WithHttpClient(client HttpClient) option {
return func(c *Client) {
c.httpClient = client
}
}
// WithBearerToken uses provided bearer token acquired from SSO for authentication
func WithBearerToken(token string) option {
return WithAuthHeader("Bearer " + token)
}
// WithPAT uses provided Personal Access Token
// (created from NetBird Management Dashboard) for authentication
func WithPAT(token string) option {
return WithAuthHeader("Token " + token)
}
// WithManagementURL overrides target NetBird Management server
func WithManagementURL(url string) option {
return func(c *Client) {
c.managementURL = url
}
}
// WithAuthHeader overrides auth header completely, this should generally not be used
// and WithBearerToken or WithPAT should be used instead
func WithAuthHeader(value string) option {
return func(c *Client) {
c.authHeader = value
}
}

View File

@@ -0,0 +1,108 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// PeersAPI APIs for peers, do not use directly
type PeersAPI struct {
c *Client
}
// PeersListOption options for Peers List API
type PeersListOption func() (string, string)
func PeerNameFilter(name string) PeersListOption {
return func() (string, string) {
return "name", name
}
}
func PeerIPFilter(ip string) PeersListOption {
return func() (string, string) {
return "ip", ip
}
}
// List list all peers
// See more: https://docs.netbird.io/api/resources/peers#list-all-peers
func (a *PeersAPI) List(ctx context.Context, opts ...PeersListOption) ([]api.Peer, error) {
query := make(map[string]string)
for _, o := range opts {
k, v := o()
query[k] = v
}
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil, query)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Peer](resp)
return ret, err
}
// Get retrieve a peer
// See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer
func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Peer](resp)
return &ret, err
}
// Update update information for a peer
// See more: https://docs.netbird.io/api/resources/peers#update-a-peer
func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApiPeersPeerIdJSONRequestBody) (*api.Peer, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Peer](resp)
return &ret, err
}
// Delete delete a peer
// See more: https://docs.netbird.io/api/resources/peers#delete-a-peer
func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// ListAccessiblePeers list all peers that the specified peer can connect to within the network
// See more: https://docs.netbird.io/api/resources/peers#list-accessible-peers
func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]api.Peer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Peer](resp)
return ret, err
}

View File

@@ -0,0 +1,212 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testPeer = api.Peer{
ApprovalRequired: false,
Connected: false,
ConnectionIp: "127.0.0.1",
DnsLabel: "test",
Id: "Test",
}
)
func TestPeers_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPeer, ret[0])
})
}
func TestPeers_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeers_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPeer)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testPeer, *ret)
})
}
func TestPeers_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeers_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiPeersPeerIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, req.InactivityExpirationEnabled)
retBytes, _ := json.Marshal(testPeer)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Update(context.Background(), "Test", api.PutApiPeersPeerIdJSONRequestBody{
InactivityExpirationEnabled: true,
})
require.NoError(t, err)
assert.Equal(t, testPeer, *ret)
})
}
func TestPeers_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Update(context.Background(), "Test", api.PutApiPeersPeerIdJSONRequestBody{
InactivityExpirationEnabled: false,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeers_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Peers.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestPeers_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Peers.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestPeers_ListAccessiblePeers_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.ListAccessiblePeers(context.Background(), "Test")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPeer, ret[0])
})
}
func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.ListAccessiblePeers(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
peers, err := c.Peers.List(context.Background())
require.NoError(t, err)
require.NotEmpty(t, peers)
filteredPeers, err := c.Peers.List(context.Background(), rest.PeerIPFilter("192.168.10.0"))
require.NoError(t, err)
require.Empty(t, filteredPeers)
peer, err := c.Peers.Get(context.Background(), peers[0].Id)
require.NoError(t, err)
assert.Equal(t, peers[0].Id, peer.Id)
peer, err = c.Peers.Update(context.Background(), peer.Id, api.PeerRequest{
LoginExpirationEnabled: true,
Name: "Test",
SshEnabled: false,
ApprovalRequired: ptr(false),
InactivityExpirationEnabled: false,
})
require.NoError(t, err)
assert.Equal(t, true, peer.LoginExpirationEnabled)
accessiblePeers, err := c.Peers.ListAccessiblePeers(context.Background(), peer.Id)
require.NoError(t, err)
assert.Empty(t, accessiblePeers)
err = c.Peers.Delete(context.Background(), peer.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,96 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// PoliciesAPI APIs for Policies, do not use directly
type PoliciesAPI struct {
c *Client
}
// List list all policies
// See more: https://docs.netbird.io/api/resources/policies#list-all-policies
func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
path := "/api/policies"
resp, err := a.c.NewRequest(ctx, "GET", path, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Policy](resp)
return ret, err
}
// Get get policy info
// See more: https://docs.netbird.io/api/resources/policies#retrieve-a-policy
func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
// Create create new policy
// See more: https://docs.netbird.io/api/resources/policies#create-a-policy
func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSONRequestBody) (*api.Policy, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
// Update update policy info
// See more: https://docs.netbird.io/api/resources/policies#update-a-policy
func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.PutApiPoliciesPolicyIdJSONRequestBody) (*api.Policy, error) {
path := "/api/policies/" + policyID
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
// Delete delete policy
// See more: https://docs.netbird.io/api/resources/policies#delete-a-policy
func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,241 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testPolicy = api.Policy{
Name: "wow",
Id: ptr("Test"),
Enabled: false,
}
)
func TestPolicies_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Policy{testPolicy})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPolicy, ret[0])
})
}
func TestPolicies_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPolicies_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPolicy)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testPolicy, *ret)
})
}
func TestPolicies_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPolicies_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiPoliciesPolicyIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testPolicy)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Create(context.Background(), api.PostApiPoliciesJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testPolicy, *ret)
})
}
func TestPolicies_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Create(context.Background(), api.PostApiPoliciesJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPolicies_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiPoliciesPolicyIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testPolicy)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Update(context.Background(), "Test", api.PutApiPoliciesPolicyIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testPolicy, *ret)
})
}
func TestPolicies_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Update(context.Background(), "Test", api.PutApiPoliciesPolicyIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPolicies_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Policies.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestPolicies_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Policies.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestPolicies_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
policies, err := c.Policies.List(context.Background())
require.NoError(t, err)
require.NotEmpty(t, policies)
policy, err := c.Policies.Get(context.Background(), *policies[0].Id)
require.NoError(t, err)
assert.Equal(t, *policies[0].Id, *policy.Id)
policy, err = c.Policies.Update(context.Background(), *policy.Id, api.PolicyCreate{
Description: ptr("Test Policy"),
Enabled: false,
Name: "Test",
Rules: []api.PolicyRuleUpdate{
{
Action: api.PolicyRuleUpdateAction(policy.Rules[0].Action),
Bidirectional: true,
Description: ptr("Test Policy"),
Sources: ptr([]string{(*policy.Rules[0].Sources)[0].Id}),
Destinations: ptr([]string{(*policy.Rules[0].Destinations)[0].Id}),
Enabled: false,
Protocol: api.PolicyRuleUpdateProtocolAll,
},
},
SourcePostureChecks: nil,
})
require.NoError(t, err)
assert.Equal(t, "Test Policy", *policy.Rules[0].Description)
policy, err = c.Policies.Create(context.Background(), api.PolicyUpdate{
Description: ptr("Test Policy 2"),
Enabled: false,
Name: "Test",
Rules: []api.PolicyRuleUpdate{
{
Action: api.PolicyRuleUpdateAction(policy.Rules[0].Action),
Bidirectional: true,
Description: ptr("Test Policy 2"),
Sources: ptr([]string{(*policy.Rules[0].Sources)[0].Id}),
Destinations: ptr([]string{(*policy.Rules[0].Destinations)[0].Id}),
Enabled: false,
Protocol: api.PolicyRuleUpdateProtocolAll,
},
},
SourcePostureChecks: nil,
})
require.NoError(t, err)
err = c.Policies.Delete(context.Background(), *policy.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// PostureChecksAPI APIs for PostureChecks, do not use directly
type PostureChecksAPI struct {
c *Client
}
// List list all posture checks
// See more: https://docs.netbird.io/api/resources/posture-checks#list-all-posture-checks
func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.PostureCheck](resp)
return ret, err
}
// Get get posture check info
// See more: https://docs.netbird.io/api/resources/posture-checks#retrieve-a-posture-check
func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api.PostureCheck, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
// Create create new posture check
// See more: https://docs.netbird.io/api/resources/posture-checks#create-a-posture-check
func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostureChecksJSONRequestBody) (*api.PostureCheck, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
// Update update posture check info
// See more: https://docs.netbird.io/api/resources/posture-checks#update-a-posture-check
func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, request api.PutApiPostureChecksPostureCheckIdJSONRequestBody) (*api.PostureCheck, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
// Delete delete posture check
// See more: https://docs.netbird.io/api/resources/posture-checks#delete-a-posture-check
func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,233 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testPostureCheck = api.PostureCheck{
Id: "Test",
Name: "wow",
}
)
func TestPostureChecks_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.PostureCheck{testPostureCheck})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPostureCheck, ret[0])
})
}
func TestPostureChecks_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPostureChecks_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPostureCheck)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testPostureCheck, *ret)
})
}
func TestPostureChecks_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPostureChecks_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostureCheckUpdate
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testPostureCheck)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testPostureCheck, *ret)
})
}
func TestPostureChecks_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPostureChecks_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostureCheckUpdate
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testPostureCheck)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Update(context.Background(), "Test", api.PostureCheckUpdate{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testPostureCheck, *ret)
})
}
func TestPostureChecks_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Update(context.Background(), "Test", api.PostureCheckUpdate{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPostureChecks_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.PostureChecks.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestPostureChecks_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.PostureChecks.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestPostureChecks_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
check, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
Name: "Test",
Description: "Testing",
Checks: &api.Checks{
OsVersionCheck: &api.OSVersionCheck{
Windows: &api.MinKernelVersionCheck{
MinKernelVersion: "0.0.0",
},
},
},
})
require.NoError(t, err)
assert.Equal(t, "Test", check.Name)
checks, err := c.PostureChecks.List(context.Background())
require.NoError(t, err)
assert.Len(t, checks, 1)
check, err = c.PostureChecks.Update(context.Background(), check.Id, api.PostureCheckUpdate{
Name: "Tests",
Description: "Testings",
Checks: &api.Checks{
GeoLocationCheck: &api.GeoLocationCheck{
Action: api.GeoLocationCheckActionAllow, Locations: []api.Location{
{
CityName: ptr("Cairo"),
CountryCode: "EG",
},
},
},
},
})
require.NoError(t, err)
assert.Equal(t, "Testings", *check.Description)
check, err = c.PostureChecks.Get(context.Background(), check.Id)
require.NoError(t, err)
assert.Equal(t, "Tests", check.Name)
err = c.PostureChecks.Delete(context.Background(), check.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// RoutesAPI APIs for Routes, do not use directly
type RoutesAPI struct {
c *Client
}
// List list all routes
// See more: https://docs.netbird.io/api/resources/routes#list-all-routes
func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Route](resp)
return ret, err
}
// Get get route info
// See more: https://docs.netbird.io/api/resources/routes#retrieve-a-route
func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes/"+routeID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
// Create create new route
// See more: https://docs.netbird.io/api/resources/routes#create-a-route
func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONRequestBody) (*api.Route, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
// Update update route info
// See more: https://docs.netbird.io/api/resources/routes#update-a-route
func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutApiRoutesRouteIdJSONRequestBody) (*api.Route, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
// Delete delete route
// See more: https://docs.netbird.io/api/resources/routes#delete-a-route
func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/routes/"+routeID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,231 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testRoute = api.Route{
Id: "Test",
Domains: ptr([]string{"google.com"}),
}
)
func TestRoutes_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Route{testRoute})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testRoute, ret[0])
})
}
func TestRoutes_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestRoutes_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testRoute)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testRoute, *ret)
})
}
func TestRoutes_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestRoutes_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiRoutesJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "meow", req.Description)
retBytes, _ := json.Marshal(testRoute)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Create(context.Background(), api.PostApiRoutesJSONRequestBody{
Description: "meow",
})
require.NoError(t, err)
assert.Equal(t, testRoute, *ret)
})
}
func TestRoutes_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Create(context.Background(), api.PostApiRoutesJSONRequestBody{
Description: "meow",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestRoutes_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiRoutesRouteIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "meow", req.Description)
retBytes, _ := json.Marshal(testRoute)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Update(context.Background(), "Test", api.PutApiRoutesRouteIdJSONRequestBody{
Description: "meow",
})
require.NoError(t, err)
assert.Equal(t, testRoute, *ret)
})
}
func TestRoutes_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Update(context.Background(), "Test", api.PutApiRoutesRouteIdJSONRequestBody{
Description: "meow",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestRoutes_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Routes.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestRoutes_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Routes.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestRoutes_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
route, err := c.Routes.Create(context.Background(), api.RouteRequest{
Description: "Meow",
Enabled: false,
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
PeerGroups: ptr([]string{"cs1tnh0hhcjnqoiuebeg"}),
Domains: ptr([]string{"google.com"}),
Masquerade: true,
Metric: 9999,
KeepRoute: false,
NetworkId: "Test",
})
require.NoError(t, err)
assert.Equal(t, "Test", route.NetworkId)
routes, err := c.Routes.List(context.Background())
require.NoError(t, err)
assert.Len(t, routes, 1)
route, err = c.Routes.Update(context.Background(), route.Id, api.RouteRequest{
Description: "Testings",
Enabled: false,
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
PeerGroups: ptr([]string{"cs1tnh0hhcjnqoiuebeg"}),
Domains: ptr([]string{"google.com"}),
Masquerade: true,
Metric: 9999,
KeepRoute: false,
NetworkId: "Tests",
})
require.NoError(t, err)
assert.Equal(t, "Testings", route.Description)
route, err = c.Routes.Get(context.Background(), route.Id)
require.NoError(t, err)
assert.Equal(t, "Tests", route.NetworkId)
err = c.Routes.Delete(context.Background(), route.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,94 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// SetupKeysAPI APIs for Setup keys, do not use directly
type SetupKeysAPI struct {
c *Client
}
// List list all setup keys
// See more: https://docs.netbird.io/api/resources/setup-keys#list-all-setup-keys
func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.SetupKey](resp)
return ret, err
}
// Get get setup key info
// See more: https://docs.netbird.io/api/resources/setup-keys#retrieve-a-setup-key
func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKey, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SetupKey](resp)
return &ret, err
}
// Create generate new Setup Key
// See more: https://docs.netbird.io/api/resources/setup-keys#create-a-setup-key
func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJSONRequestBody) (*api.SetupKeyClear, error) {
path := "/api/setup-keys"
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", path, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SetupKeyClear](resp)
return &ret, err
}
// Update generate new Setup Key
// See more: https://docs.netbird.io/api/resources/setup-keys#update-a-setup-key
func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request api.PutApiSetupKeysKeyIdJSONRequestBody) (*api.SetupKey, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SetupKey](resp)
return &ret, err
}
// Delete delete setup key
// See more: https://docs.netbird.io/api/resources/setup-keys#delete-a-setup-key
func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,232 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testSetupKey = api.SetupKey{
Id: "Test",
Name: "wow",
AutoGroups: []string{"meow"},
Ephemeral: true,
}
testSteupKeyGenerated = api.SetupKeyClear{
Id: "Test",
Name: "wow",
AutoGroups: []string{"meow"},
Ephemeral: true,
Key: "shhh",
}
)
func TestSetupKeys_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.SetupKey{testSetupKey})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testSetupKey, ret[0])
})
}
func TestSetupKeys_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestSetupKeys_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testSetupKey)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testSetupKey, *ret)
})
}
func TestSetupKeys_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestSetupKeys_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiSetupKeysJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, 5, req.ExpiresIn)
retBytes, _ := json.Marshal(testSteupKeyGenerated)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Create(context.Background(), api.PostApiSetupKeysJSONRequestBody{
ExpiresIn: 5,
})
require.NoError(t, err)
assert.Equal(t, testSteupKeyGenerated, *ret)
})
}
func TestSetupKeys_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Create(context.Background(), api.PostApiSetupKeysJSONRequestBody{
ExpiresIn: 5,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestSetupKeys_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiSetupKeysKeyIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, req.Revoked)
retBytes, _ := json.Marshal(testSetupKey)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Update(context.Background(), "Test", api.PutApiSetupKeysKeyIdJSONRequestBody{
Revoked: true,
})
require.NoError(t, err)
assert.Equal(t, testSetupKey, *ret)
})
}
func TestSetupKeys_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Update(context.Background(), "Test", api.PutApiSetupKeysKeyIdJSONRequestBody{
Revoked: true,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestSetupKeys_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.SetupKeys.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestSetupKeys_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.SetupKeys.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestSetupKeys_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
group, err := c.Groups.Create(context.Background(), api.GroupRequest{
Name: "Test",
})
require.NoError(t, err)
skClear, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
AutoGroups: []string{group.Id},
Ephemeral: ptr(false),
Name: "test",
Type: "reusable",
})
require.NoError(t, err)
assert.Equal(t, true, skClear.Valid)
keys, err := c.SetupKeys.List(context.Background())
require.NoError(t, err)
assert.Len(t, keys, 2)
sk, err := c.SetupKeys.Update(context.Background(), skClear.Id, api.SetupKeyRequest{
Revoked: true,
AutoGroups: []string{group.Id},
})
require.NoError(t, err)
sk, err = c.SetupKeys.Get(context.Background(), sk.Id)
require.NoError(t, err)
assert.Equal(t, false, sk.Valid)
err = c.SetupKeys.Delete(context.Background(), sk.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,74 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// TokensAPI APIs for PATs, do not use directly
type TokensAPI struct {
c *Client
}
// List list user tokens
// See more: https://docs.netbird.io/api/resources/tokens#list-all-tokens
func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAccessToken, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.PersonalAccessToken](resp)
return ret, err
}
// Get get user token info
// See more: https://docs.netbird.io/api/resources/tokens#retrieve-a-token
func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.PersonalAccessToken, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PersonalAccessToken](resp)
return &ret, err
}
// Create generate new PAT for user
// See more: https://docs.netbird.io/api/resources/tokens#create-a-token
func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostApiUsersUserIdTokensJSONRequestBody) (*api.PersonalAccessTokenGenerated, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PersonalAccessTokenGenerated](resp)
return &ret, err
}
// Delete delete user token
// See more: https://docs.netbird.io/api/resources/tokens#delete-a-token
func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,180 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testToken = api.PersonalAccessToken{
Id: "Test",
CreatedAt: time.Time{},
CreatedBy: "meow",
ExpirationDate: time.Time{},
LastUsed: nil,
Name: "wow",
}
testTokenGenerated = api.PersonalAccessTokenGenerated{
PersonalAccessToken: testToken,
PlainToken: "shhh",
}
)
func TestTokens_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.PersonalAccessToken{testToken})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.List(context.Background(), "meow")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testToken, ret[0])
})
}
func TestTokens_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.List(context.Background(), "meow")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestTokens_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testToken)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.Get(context.Background(), "meow", "Test")
require.NoError(t, err)
assert.Equal(t, testToken, *ret)
})
}
func TestTokens_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.Get(context.Background(), "meow", "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestTokens_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiUsersUserIdTokensJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, 5, req.ExpiresIn)
retBytes, _ := json.Marshal(testTokenGenerated)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.Create(context.Background(), "meow", api.PostApiUsersUserIdTokensJSONRequestBody{
ExpiresIn: 5,
})
require.NoError(t, err)
assert.Equal(t, testTokenGenerated, *ret)
})
}
func TestTokens_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.Create(context.Background(), "meow", api.PostApiUsersUserIdTokensJSONRequestBody{
ExpiresIn: 5,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestTokens_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Tokens.Delete(context.Background(), "meow", "Test")
require.NoError(t, err)
})
}
func TestTokens_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Tokens.Delete(context.Background(), "meow", "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestTokens_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
tokenClear, err := c.Tokens.Create(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", api.PersonalAccessTokenRequest{
Name: "Test",
ExpiresIn: 365,
})
require.NoError(t, err)
assert.Equal(t, "Test", tokenClear.PersonalAccessToken.Name)
tokens, err := c.Tokens.List(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003")
require.NoError(t, err)
assert.Len(t, tokens, 2)
token, err := c.Tokens.Get(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", tokenClear.PersonalAccessToken.Id)
require.NoError(t, err)
assert.Equal(t, "Test", token.Name)
err = c.Tokens.Delete(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", token.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,107 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// UsersAPI APIs for users, do not use directly
type UsersAPI struct {
c *Client
}
// List list all users, only returns one user always
// See more: https://docs.netbird.io/api/resources/users#list-all-users
func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.User](resp)
return ret, err
}
// Create create user
// See more: https://docs.netbird.io/api/resources/users#create-a-user
func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONRequestBody) (*api.User, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.User](resp)
return &ret, err
}
// Update update user settings
// See more: https://docs.netbird.io/api/resources/users#update-a-user
func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApiUsersUserIdJSONRequestBody) (*api.User, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.User](resp)
return &ret, err
}
// Delete delete user
// See more: https://docs.netbird.io/api/resources/users#delete-a-user
func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// ResendInvitation resend user invitation
// See more: https://docs.netbird.io/api/resources/users#resend-user-invitation
func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// Current gets the current user info
// See more: https://docs.netbird.io/api/resources/users#retrieve-current-user
func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/current", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.User](resp)
return &ret, err
}

View File

@@ -0,0 +1,258 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testUser = api.User{
Id: "Test",
AutoGroups: []string{"test-group"},
Email: "test@test.com",
IsBlocked: false,
IsCurrent: ptr(false),
IsServiceUser: ptr(false),
Issued: ptr("api"),
LastLogin: &time.Time{},
Name: "M. Essam",
Role: "user",
Status: api.UserStatusActive,
}
)
func TestUsers_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.User{testUser})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testUser, ret[0])
})
}
func TestUsers_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestUsers_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiUsersJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, []string{"meow"}, req.AutoGroups)
retBytes, _ := json.Marshal(testUser)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Create(context.Background(), api.PostApiUsersJSONRequestBody{
AutoGroups: []string{"meow"},
})
require.NoError(t, err)
assert.Equal(t, testUser, *ret)
})
}
func TestUsers_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Create(context.Background(), api.PostApiUsersJSONRequestBody{
AutoGroups: []string{"meow"},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiUsersUserIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, req.IsBlocked)
retBytes, _ := json.Marshal(testUser)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Update(context.Background(), "Test", api.PutApiUsersUserIdJSONRequestBody{
IsBlocked: true,
})
require.NoError(t, err)
assert.Equal(t, testUser, *ret)
})
}
func TestUsers_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Update(context.Background(), "Test", api.PutApiUsersUserIdJSONRequestBody{
IsBlocked: true,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Users.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestUsers_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestUsers_ResendInvitation_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
w.WriteHeader(200)
})
err := c.Users.ResendInvitation(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestUsers_ResendInvitation_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.ResendInvitation(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestUsers_Current_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testUser)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Current(context.Background())
require.NoError(t, err)
assert.Equal(t, testUser, *ret)
})
}
func TestUsers_Current_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Current(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestUsers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
// rest client PAT is owner's
current, err := c.Users.Current(context.Background())
require.NoError(t, err)
assert.Equal(t, "a23efe53-63fb-11ec-90d6-0242ac120003", current.Id)
assert.Equal(t, "owner", current.Role)
user, err := c.Users.Create(context.Background(), api.UserCreateRequest{
AutoGroups: []string{},
Email: ptr("test@example.com"),
IsServiceUser: true,
Name: ptr("Nobody"),
Role: "user",
})
require.NoError(t, err)
assert.Equal(t, "Nobody", user.Name)
users, err := c.Users.List(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, users)
user, err = c.Users.Update(context.Background(), user.Id, api.UserRequest{
AutoGroups: []string{},
Role: "admin",
})
require.NoError(t, err)
assert.Equal(t, "admin", user.Role)
err = c.Users.Delete(context.Background(), user.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,45 @@
package domain
import (
"strings"
"golang.org/x/net/idna"
)
// Domain represents a punycode-encoded domain string.
// This should only be converted from a string when the string already is in punycode, otherwise use FromString.
type Domain string
// String converts the Domain to a non-punycode string.
// For an infallible conversion, use SafeString.
func (d Domain) String() (string, error) {
unicode, err := idna.ToUnicode(string(d))
if err != nil {
return "", err
}
return unicode, nil
}
// SafeString converts the Domain to a non-punycode string, falling back to the punycode string if conversion fails.
func (d Domain) SafeString() string {
str, err := d.String()
if err != nil {
return string(d)
}
return str
}
// PunycodeString returns the punycode representation of the Domain.
// This should only be used if a punycode domain is expected but only a string is supported.
func (d Domain) PunycodeString() string {
return string(d)
}
// FromString creates a Domain from a string, converting it to punycode.
func FromString(s string) (Domain, error) {
ascii, err := idna.ToASCII(s)
if err != nil {
return "", err
}
return Domain(strings.ToLower(ascii)), nil
}

View File

@@ -0,0 +1 @@
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=

View File

@@ -0,0 +1,108 @@
package domain
import (
"sort"
"strings"
)
// List is a slice of punycode-encoded domain strings.
type List []Domain
// ToStringList converts a List to a slice of string.
func (d List) ToStringList() ([]string, error) {
var list []string
for _, domain := range d {
s, err := domain.String()
if err != nil {
return nil, err
}
list = append(list, s)
}
return list, nil
}
// ToPunycodeList converts the List to a slice of Punycode-encoded domain strings.
func (d List) ToPunycodeList() []string {
var list []string
for _, domain := range d {
list = append(list, string(domain))
}
return list
}
// ToSafeStringList converts the List to a slice of non-punycode strings.
// If a domain cannot be converted, the original string is used.
func (d List) ToSafeStringList() []string {
var list []string
for _, domain := range d {
list = append(list, domain.SafeString())
}
return list
}
// String converts List to a comma-separated string.
func (d List) String() (string, error) {
list, err := d.ToStringList()
if err != nil {
return "", err
}
return strings.Join(list, ", "), nil
}
// SafeString converts List to a comma-separated non-punycode string.
// If a domain cannot be converted, the original string is used.
func (d List) SafeString() string {
str, err := d.String()
if err != nil {
return d.PunycodeString()
}
return str
}
// PunycodeString converts the List to a comma-separated string of Punycode-encoded domains.
func (d List) PunycodeString() string {
return strings.Join(d.ToPunycodeList(), ", ")
}
func (d List) Equal(domains List) bool {
if len(d) != len(domains) {
return false
}
sort.Slice(d, func(i, j int) bool {
return d[i] < d[j]
})
sort.Slice(domains, func(i, j int) bool {
return domains[i] < domains[j]
})
for i, domain := range d {
if domain != domains[i] {
return false
}
}
return true
}
// FromStringList creates a DomainList from a slice of string.
func FromStringList(s []string) (List, error) {
var dl List
for _, domain := range s {
d, err := FromString(domain)
if err != nil {
return nil, err
}
dl = append(dl, d)
}
return dl, nil
}
// FromPunycodeList creates a List from a slice of Punycode-encoded domain strings.
func FromPunycodeList(s []string) List {
var dl List
for _, domain := range s {
dl = append(dl, Domain(strings.ToLower(domain)))
}
return dl
}

View File

@@ -0,0 +1,49 @@
package domain
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_EqualReturnsTrueForIdenticalLists(t *testing.T) {
list1 := List{"domain1", "domain2", "domain3"}
list2 := List{"domain1", "domain2", "domain3"}
assert.True(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForDifferentLengths(t *testing.T) {
list1 := List{"domain1", "domain2"}
list2 := List{"domain1", "domain2", "domain3"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForDifferentElements(t *testing.T) {
list1 := List{"domain1", "domain2", "domain3"}
list2 := List{"domain1", "domain4", "domain3"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsTrueForUnsortedIdenticalLists(t *testing.T) {
list1 := List{"domain3", "domain1", "domain2"}
list2 := List{"domain1", "domain2", "domain3"}
assert.True(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForEmptyAndNonEmptyList(t *testing.T) {
list1 := List{}
list2 := List{"domain1"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsTrueForBothEmptyLists(t *testing.T) {
list1 := List{}
list2 := List{}
assert.True(t, list1.Equal(list2))
}

View File

@@ -0,0 +1,56 @@
package domain
import (
"fmt"
"regexp"
"strings"
)
const maxDomains = 32
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func ValidateDomains(domains []string) (List, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty")
}
if len(domains) > maxDomains {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}
var domainList List
for _, d := range domains {
// handles length and idna conversion
punycode, err := FromString(d)
if err != nil {
return domainList, fmt.Errorf("convert domain to punycode: %s: %w", d, err)
}
if !domainRegex.MatchString(string(punycode)) {
return domainList, fmt.Errorf("invalid domain format: %s", d)
}
domainList = append(domainList, punycode)
}
return domainList, nil
}
// ValidateDomainsList checks if each domain in the list is valid
func ValidateDomainsList(domains []string) error {
if len(domains) == 0 {
return nil
}
if len(domains) > maxDomains {
return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}
for _, d := range domains {
d := strings.ToLower(d)
if !domainRegex.MatchString(d) {
return fmt.Errorf("invalid domain format: %s", d)
}
}
return nil
}

View File

@@ -0,0 +1,185 @@
package domain
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateDomains(t *testing.T) {
tests := []struct {
name string
domains []string
expected List
wantErr bool
}{
{
name: "Empty list",
domains: nil,
expected: nil,
wantErr: true,
},
{
name: "Valid ASCII domain",
domains: []string{"sub.ex-ample.com"},
expected: List{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Valid Unicode domain",
domains: []string{"münchen.de"},
expected: List{"xn--mnchen-3ya.de"},
wantErr: false,
},
{
name: "Valid Unicode, all labels",
domains: []string{"中国.中国.中国"},
expected: List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"},
wantErr: false,
},
{
name: "With underscores",
domains: []string{"_jabber._tcp.gmail.com"},
expected: List{"_jabber._tcp.gmail.com"},
wantErr: false,
},
{
name: "Invalid domain format",
domains: []string{"-example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid domain format 2",
domains: []string{"example.com-"},
expected: nil,
wantErr: true,
},
{
name: "Multiple domains valid and invalid",
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
expected: List{"google.com"},
wantErr: true,
},
{
name: "Valid wildcard domain",
domains: []string{"*.example.com"},
expected: List{"*.example.com"},
wantErr: false,
},
{
name: "Wildcard with dot domain",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
},
{
name: "Wildcard with dot domain",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid wildcard domain",
domains: []string{"a.*.example.com"},
expected: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ValidateDomains(tt.domains)
assert.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, got, tt.expected)
})
}
}
func TestValidateDomainsList(t *testing.T) {
validDomains := make([]string, maxDomains)
for i := range maxDomains {
validDomains[i] = fmt.Sprintf("example%d.com", i)
}
tests := []struct {
name string
domains []string
wantErr bool
}{
{
name: "Empty list",
domains: nil,
wantErr: false,
},
{
name: "Single valid ASCII domain",
domains: []string{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Underscores in labels",
domains: []string{"_jabber._tcp.gmail.com"},
wantErr: false,
},
{
// Unlike ValidateDomains (which converts to punycode),
// ValidateDomainsStrSlice will fail on non-ASCII domain chars.
name: "Unicode domain fails (no punycode conversion)",
domains: []string{"münchen.de"},
wantErr: true,
},
{
name: "Invalid domain format - leading dash",
domains: []string{"-example.com"},
wantErr: true,
},
{
name: "Invalid domain format - trailing dash",
domains: []string{"example-.com"},
wantErr: true,
},
{
name: "Multiple domains with a valid one, then invalid",
domains: []string{"google.com", "invalid_domain.com-"},
wantErr: true,
},
{
name: "Valid wildcard domain",
domains: []string{"*.example.com"},
wantErr: false,
},
{
name: "Wildcard with leading dot - invalid",
domains: []string{".*.example.com"},
wantErr: true,
},
{
name: "Invalid wildcard with multiple asterisks",
domains: []string{"a.*.example.com"},
wantErr: true,
},
{
name: "Exactly maxDomains items (valid)",
domains: validDomains,
wantErr: false,
},
{
name: "Exceeds maxDomains items",
domains: append(validDomains, "extra.com"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateDomainsList(tt.domains)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./management.proto --go_out=../ --go-grpc_out=../
cd "$old_pwd"

View File

@@ -0,0 +1,2 @@
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,542 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";
import "google/protobuf/duration.proto";
option go_package = "/proto";
package management;
service ManagementService {
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
// Returns encrypted LoginResponse in EncryptedMessage.Body
rpc Login(EncryptedMessage) returns (EncryptedMessage) {}
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
// Returns encrypted SyncResponse in EncryptedMessage.Body
rpc Sync(EncryptedMessage) returns (stream EncryptedMessage) {}
// Exposes a Wireguard public key of the Management service.
// This key is used to support message encryption between client and server
rpc GetServerKey(Empty) returns (ServerKeyResponse) {}
// health check endpoint
rpc isHealthy(Empty) returns (Empty) {}
// Exposes a device authorization flow information
// This is used for initiating a Oauth 2 device authorization grant flow
// which will be used by our clients to Login.
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
rpc GetDeviceAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
// Exposes a PKCE authorization code flow information
// This is used for initiating a Oauth 2 authorization grant flow
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
rpc GetPKCEAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
rpc SyncMeta(EncryptedMessage) returns (Empty) {}
// Logout logs out the peer and removes it from the management server
rpc Logout(EncryptedMessage) returns (Empty) {}
}
message EncryptedMessage {
// Wireguard public key
string wgPubKey = 1;
// encrypted message Body
bytes body = 2;
// Version of the Netbird Management Service protocol
int32 version = 3;
}
message SyncRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
message SyncResponse {
// Global config
NetbirdConfig netbirdConfig = 1;
// Deprecated. Use NetworkMap.PeerConfig
PeerConfig peerConfig = 2;
// Deprecated. Use NetworkMap.RemotePeerConfig
repeated RemotePeerConfig remotePeers = 3;
// Indicates whether remotePeers array is empty or not to bypass protobuf null and empty array equality.
// Deprecated. Use NetworkMap.remotePeersIsEmpty
bool remotePeersIsEmpty = 4;
NetworkMap NetworkMap = 5;
// Posture checks to be evaluated by client
repeated Checks Checks = 6;
}
message SyncMetaRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
message LoginRequest {
// Pre-authorized setup key (can be empty)
string setupKey = 1;
// Meta data of the peer (e.g. name, os_name, os_version,
PeerSystemMeta meta = 2;
// SSO token (can be empty)
string jwtToken = 3;
// Can be absent for now.
PeerKeys peerKeys = 4;
repeated string dnsLabels = 5;
}
// PeerKeys is additional peer info like SSH pub key and WireGuard public key.
// This message is sent on Login or register requests, or when a key rotation has to happen.
message PeerKeys {
// sshPubKey represents a public SSH key of the peer. Can be absent.
bytes sshPubKey = 1;
// wgPubKey represents a public WireGuard key of the peer. Can be absent.
bytes wgPubKey = 2;
}
// Environment is part of the PeerSystemMeta and describes the environment the agent is running in.
message Environment {
// cloud is the cloud provider the agent is running in if applicable.
string cloud = 1;
// platform is the platform the agent is running on if applicable.
string platform = 2;
}
// File represents a file on the system.
message File {
// path is the path to the file.
string path = 1;
// exist indicate whether the file exists.
bool exist = 2;
// processIsRunning indicates whether the file is a running process or not.
bool processIsRunning = 3;
}
message Flags {
bool rosenpassEnabled = 1;
bool rosenpassPermissive = 2;
bool serverSSHAllowed = 3;
bool disableClientRoutes = 4;
bool disableServerRoutes = 5;
bool disableDNS = 6;
bool disableFirewall = 7;
bool blockLANAccess = 8;
bool blockInbound = 9;
bool lazyConnectionEnabled = 10;
}
// PeerSystemMeta is machine meta data like OS and version.
message PeerSystemMeta {
string hostname = 1;
string goOS = 2;
string kernel = 3;
string core = 4;
string platform = 5;
string OS = 6;
string netbirdVersion = 7;
string uiVersion = 8;
string kernelVersion = 9;
string OSVersion = 10;
repeated NetworkAddress networkAddresses = 11;
string sysSerialNumber = 12;
string sysProductName = 13;
string sysManufacturer = 14;
Environment environment = 15;
repeated File files = 16;
Flags flags = 17;
}
message LoginResponse {
// Global config
NetbirdConfig netbirdConfig = 1;
// Peer local config
PeerConfig peerConfig = 2;
// Posture checks to be evaluated by client
repeated Checks Checks = 3;
}
message ServerKeyResponse {
// Server's Wireguard public key
string key = 1;
// Key expiration timestamp after which the key should be fetched again by the client
google.protobuf.Timestamp expiresAt = 2;
// Version of the Netbird Management Service protocol
int32 version = 3;
}
message Empty {}
// NetbirdConfig is a common configuration of any Netbird peer. It contains STUN, TURN, Signal and Management servers configurations
message NetbirdConfig {
// a list of STUN servers
repeated HostConfig stuns = 1;
// a list of TURN servers
repeated ProtectedHostConfig turns = 2;
// a Signal server config
HostConfig signal = 3;
RelayConfig relay = 4;
FlowConfig flow = 5;
}
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
message HostConfig {
// URI of the resource e.g. turns://stun.netbird.io:4430 or signal.netbird.io:10000
string uri = 1;
Protocol protocol = 2;
enum Protocol {
UDP = 0;
TCP = 1;
HTTP = 2;
HTTPS = 3;
DTLS = 4;
}
}
message RelayConfig {
repeated string urls = 1;
string tokenPayload = 2;
string tokenSignature = 3;
}
message FlowConfig {
string url = 1;
string tokenPayload = 2;
string tokenSignature = 3;
google.protobuf.Duration interval = 4;
bool enabled = 5;
// counters determines if flow packets and bytes counters should be sent
bool counters = 6;
// exitNodeCollection determines if event collection on exit nodes should be enabled
bool exitNodeCollection = 7;
// dnsCollection determines if DNS event collection should be enabled
bool dnsCollection = 8;
}
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
message ProtectedHostConfig {
HostConfig hostConfig = 1;
string user = 2;
string password = 3;
}
// PeerConfig represents a configuration of a "our" peer.
// The properties are used to configure local Wireguard
message PeerConfig {
// Peer's virtual IP address within the Netbird VPN (a Wireguard address config)
string address = 1;
// Netbird DNS server (a Wireguard DNS config)
string dns = 2;
// SSHConfig of the peer.
SSHConfig sshConfig = 3;
// Peer fully qualified domain name
string fqdn = 4;
bool RoutingPeerDnsResolutionEnabled = 5;
bool LazyConnectionEnabled = 6;
}
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
message NetworkMap {
// Serial is an ID of the network state to be used by clients to order updates.
// The larger the Serial the newer the configuration.
// E.g. the client app should keep track of this id locally and discard all the configurations with a lower value
uint64 Serial = 1;
// PeerConfig represents configuration of a peer
PeerConfig peerConfig = 2;
// RemotePeerConfig represents a list of remote peers that the receiver can connect to
repeated RemotePeerConfig remotePeers = 3;
// Indicates whether remotePeers array is empty or not to bypass protobuf null and empty array equality.
bool remotePeersIsEmpty = 4;
// List of routes to be applied
repeated Route Routes = 5;
// DNS config to be applied
DNSConfig DNSConfig = 6;
// RemotePeerConfig represents a list of remote peers that the receiver can connect to
repeated RemotePeerConfig offlinePeers = 7;
// FirewallRule represents a list of firewall rules to be applied to peer
repeated FirewallRule FirewallRules = 8;
// firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality.
bool firewallRulesIsEmpty = 9;
// RoutesFirewallRules represents a list of routes firewall rules to be applied to peer
repeated RouteFirewallRule routesFirewallRules = 10;
// RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality.
bool routesFirewallRulesIsEmpty = 11;
repeated ForwardingRule forwardingRules = 12;
}
// RemotePeerConfig represents a configuration of a remote peer.
// The properties are used to configure WireGuard Peers sections
message RemotePeerConfig {
// A WireGuard public key of a remote peer
string wgPubKey = 1;
// WireGuard allowed IPs of a remote peer e.g. [10.30.30.1/32]
repeated string allowedIps = 2;
// SSHConfig is a SSH config of the remote peer. SSHConfig.sshPubKey should be ignored because peer knows it's SSH key.
SSHConfig sshConfig = 3;
// Peer fully qualified domain name
string fqdn = 4;
string agentVersion = 5;
}
// SSHConfig represents SSH configurations of a peer.
message SSHConfig {
// sshEnabled indicates whether a SSH server is enabled on this peer
bool sshEnabled = 1;
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
// This property should be ignore if SSHConfig comes from PeerConfig.
bytes sshPubKey = 2;
}
// DeviceAuthorizationFlowRequest empty struct for future expansion
message DeviceAuthorizationFlowRequest {}
// DeviceAuthorizationFlow represents Device Authorization Flow information
// that can be used by the client to login initiate a Oauth 2.0 device authorization grant flow
// see https://datatracker.ietf.org/doc/html/rfc8628
message DeviceAuthorizationFlow {
// An IDP provider , (eg. Auth0)
provider Provider = 1;
ProviderConfig ProviderConfig = 2;
enum provider {
HOSTED = 0;
}
}
// PKCEAuthorizationFlowRequest empty struct for future expansion
message PKCEAuthorizationFlowRequest {}
// PKCEAuthorizationFlow represents Authorization Code Flow information
// that can be used by the client to login initiate a Oauth 2.0 authorization code grant flow
// with Proof Key for Code Exchange (PKCE). See https://datatracker.ietf.org/doc/html/rfc7636
message PKCEAuthorizationFlow {
ProviderConfig ProviderConfig = 1;
}
// ProviderConfig has all attributes needed to initiate a device/pkce authorization flow
message ProviderConfig {
// An IDP application client id
string ClientID = 1;
// An IDP application client secret
string ClientSecret = 2;
// An IDP API domain
// Deprecated. Use a DeviceAuthEndpoint and TokenEndpoint
string Domain = 3;
// An Audience for validation
string Audience = 4;
// DeviceAuthEndpoint is an endpoint to request device authentication code.
string DeviceAuthEndpoint = 5;
// TokenEndpoint is an endpoint to request auth token.
string TokenEndpoint = 6;
// Scopes provides the scopes to be included in the token request
string Scope = 7;
// UseIDToken indicates if the id token should be used for authentication
bool UseIDToken = 8;
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code.
string AuthorizationEndpoint = 9;
// RedirectURLs handles authorization code from IDP manager
repeated string RedirectURLs = 10;
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
bool DisablePromptLogin = 11;
// LoginFlags sets the PKCE flow login details
uint32 LoginFlag = 12;
}
// Route represents a route.Route object
message Route {
string ID = 1;
string Network = 2;
int64 NetworkType = 3;
string Peer = 4;
int64 Metric = 5;
bool Masquerade = 6;
string NetID = 7;
repeated string Domains = 8;
bool keepRoute = 9;
}
// DNSConfig represents a dns.Update
message DNSConfig {
bool ServiceEnable = 1;
repeated NameServerGroup NameServerGroups = 2;
repeated CustomZone CustomZones = 3;
}
// CustomZone represents a dns.CustomZone
message CustomZone {
string Domain = 1;
repeated SimpleRecord Records = 2;
}
// SimpleRecord represents a dns.SimpleRecord
message SimpleRecord {
string Name = 1;
int64 Type = 2;
string Class = 3;
int64 TTL = 4;
string RData = 5;
}
// NameServerGroup represents a dns.NameServerGroup
message NameServerGroup {
repeated NameServer NameServers = 1;
bool Primary = 2;
repeated string Domains = 3;
bool SearchDomainsEnabled = 4;
}
// NameServer represents a dns.NameServer
message NameServer {
string IP = 1;
int64 NSType = 2;
int64 Port = 3;
}
enum RuleProtocol {
UNKNOWN = 0;
ALL = 1;
TCP = 2;
UDP = 3;
ICMP = 4;
CUSTOM = 5;
}
enum RuleDirection {
IN = 0;
OUT = 1;
}
enum RuleAction {
ACCEPT = 0;
DROP = 1;
}
// FirewallRule represents a firewall rule
message FirewallRule {
string PeerIP = 1;
RuleDirection Direction = 2;
RuleAction Action = 3;
RuleProtocol Protocol = 4;
string Port = 5;
PortInfo PortInfo = 6;
// PolicyID is the ID of the policy that this rule belongs to
bytes PolicyID = 7;
}
message NetworkAddress {
string netIP = 1;
string mac = 2;
}
message Checks {
repeated string Files = 1;
}
message PortInfo {
oneof portSelection {
uint32 port = 1;
Range range = 2;
}
message Range {
uint32 start = 1;
uint32 end = 2;
}
}
// RouteFirewallRule signifies a firewall rule applicable for a routed network.
message RouteFirewallRule {
// sourceRanges IP ranges of the routing peers.
repeated string sourceRanges = 1;
// Action to be taken by the firewall when the rule is applicable.
RuleAction action = 2;
// Network prefix for the routed network.
string destination = 3;
// Protocol of the routed network.
RuleProtocol protocol = 4;
// Details about the port.
PortInfo portInfo = 5;
// IsDynamic indicates if the route is a DNS route.
bool isDynamic = 6;
// Domains is a list of domains for which the rule is applicable.
repeated string domains = 7;
// CustomProtocol is a custom protocol ID.
uint32 customProtocol = 8;
// PolicyID is the ID of the policy that this rule belongs to
bytes PolicyID = 9;
// RouteID is the ID of the route that this rule belongs to
string RouteID = 10;
}
message ForwardingRule {
// Protocol of the forwarding rule
RuleProtocol protocol = 1;
// portInfo is the ingress destination port information, where the traffic arrives in the gateway node
PortInfo destinationPort = 2;
// IP address of the translated address (remote peer) to send traffic to
bytes translatedAddress = 3;
// Translated port information, where the traffic should be forwarded to
PortInfo translatedPort = 4;
}

View File

@@ -0,0 +1,429 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
package proto
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// ManagementServiceClient is the client API for ManagementService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ManagementServiceClient interface {
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
// Returns encrypted LoginResponse in EncryptedMessage.Body
Login(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
// Returns encrypted SyncResponse in EncryptedMessage.Body
Sync(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (ManagementService_SyncClient, error)
// Exposes a Wireguard public key of the Management service.
// This key is used to support message encryption between client and server
GetServerKey(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*ServerKeyResponse, error)
// health check endpoint
IsHealthy(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
// Exposes a device authorization flow information
// This is used for initiating a Oauth 2 device authorization grant flow
// which will be used by our clients to Login.
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
GetDeviceAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// Exposes a PKCE authorization code flow information
// This is used for initiating a Oauth 2 authorization grant flow
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
// Logout logs out the peer and removes it from the management server
Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
}
type managementServiceClient struct {
cc grpc.ClientConnInterface
}
func NewManagementServiceClient(cc grpc.ClientConnInterface) ManagementServiceClient {
return &managementServiceClient{cc}
}
func (c *managementServiceClient) Login(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/management.ManagementService/Login", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) Sync(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (ManagementService_SyncClient, error) {
stream, err := c.cc.NewStream(ctx, &ManagementService_ServiceDesc.Streams[0], "/management.ManagementService/Sync", opts...)
if err != nil {
return nil, err
}
x := &managementServiceSyncClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type ManagementService_SyncClient interface {
Recv() (*EncryptedMessage, error)
grpc.ClientStream
}
type managementServiceSyncClient struct {
grpc.ClientStream
}
func (x *managementServiceSyncClient) Recv() (*EncryptedMessage, error) {
m := new(EncryptedMessage)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *managementServiceClient) GetServerKey(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*ServerKeyResponse, error) {
out := new(ServerKeyResponse)
err := c.cc.Invoke(ctx, "/management.ManagementService/GetServerKey", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) IsHealthy(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/isHealthy", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) GetDeviceAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/management.ManagementService/GetDeviceAuthorizationFlow", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/management.ManagementService/GetPKCEAuthorizationFlow", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/SyncMeta", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/Logout", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// ManagementServiceServer is the server API for ManagementService service.
// All implementations must embed UnimplementedManagementServiceServer
// for forward compatibility
type ManagementServiceServer interface {
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
// Returns encrypted LoginResponse in EncryptedMessage.Body
Login(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
// Returns encrypted SyncResponse in EncryptedMessage.Body
Sync(*EncryptedMessage, ManagementService_SyncServer) error
// Exposes a Wireguard public key of the Management service.
// This key is used to support message encryption between client and server
GetServerKey(context.Context, *Empty) (*ServerKeyResponse, error)
// health check endpoint
IsHealthy(context.Context, *Empty) (*Empty, error)
// Exposes a device authorization flow information
// This is used for initiating a Oauth 2 device authorization grant flow
// which will be used by our clients to Login.
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
GetDeviceAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// Exposes a PKCE authorization code flow information
// This is used for initiating a Oauth 2 authorization grant flow
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
// Logout logs out the peer and removes it from the management server
Logout(context.Context, *EncryptedMessage) (*Empty, error)
mustEmbedUnimplementedManagementServiceServer()
}
// UnimplementedManagementServiceServer must be embedded to have forward compatible implementations.
type UnimplementedManagementServiceServer struct {
}
func (UnimplementedManagementServiceServer) Login(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method Login not implemented")
}
func (UnimplementedManagementServiceServer) Sync(*EncryptedMessage, ManagementService_SyncServer) error {
return status.Errorf(codes.Unimplemented, "method Sync not implemented")
}
func (UnimplementedManagementServiceServer) GetServerKey(context.Context, *Empty) (*ServerKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetServerKey not implemented")
}
func (UnimplementedManagementServiceServer) IsHealthy(context.Context, *Empty) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method IsHealthy not implemented")
}
func (UnimplementedManagementServiceServer) GetDeviceAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetDeviceAuthorizationFlow not implemented")
}
func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
}
func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}
func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
}
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to ManagementServiceServer will
// result in compilation errors.
type UnsafeManagementServiceServer interface {
mustEmbedUnimplementedManagementServiceServer()
}
func RegisterManagementServiceServer(s grpc.ServiceRegistrar, srv ManagementServiceServer) {
s.RegisterService(&ManagementService_ServiceDesc, srv)
}
func _ManagementService_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).Login(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/Login",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).Login(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_Sync_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(EncryptedMessage)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(ManagementServiceServer).Sync(m, &managementServiceSyncServer{stream})
}
type ManagementService_SyncServer interface {
Send(*EncryptedMessage) error
grpc.ServerStream
}
type managementServiceSyncServer struct {
grpc.ServerStream
}
func (x *managementServiceSyncServer) Send(m *EncryptedMessage) error {
return x.ServerStream.SendMsg(m)
}
func _ManagementService_GetServerKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).GetServerKey(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/GetServerKey",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).GetServerKey(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_IsHealthy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).IsHealthy(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/isHealthy",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).IsHealthy(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_GetDeviceAuthorizationFlow_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).GetDeviceAuthorizationFlow(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/GetDeviceAuthorizationFlow",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).GetDeviceAuthorizationFlow(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_GetPKCEAuthorizationFlow_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).GetPKCEAuthorizationFlow(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/GetPKCEAuthorizationFlow",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).GetPKCEAuthorizationFlow(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).SyncMeta(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/SyncMeta",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).SyncMeta(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).Logout(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/Logout",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).Logout(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var ManagementService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "management.ManagementService",
HandlerType: (*ManagementServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Login",
Handler: _ManagementService_Login_Handler,
},
{
MethodName: "GetServerKey",
Handler: _ManagementService_GetServerKey_Handler,
},
{
MethodName: "isHealthy",
Handler: _ManagementService_IsHealthy_Handler,
},
{
MethodName: "GetDeviceAuthorizationFlow",
Handler: _ManagementService_GetDeviceAuthorizationFlow_Handler,
},
{
MethodName: "GetPKCEAuthorizationFlow",
Handler: _ManagementService_GetPKCEAuthorizationFlow_Handler,
},
{
MethodName: "SyncMeta",
Handler: _ManagementService_SyncMeta_Handler,
},
{
MethodName: "Logout",
Handler: _ManagementService_Logout_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "Sync",
Handler: _ManagementService_Sync_Handler,
ServerStreams: true,
},
},
Metadata: "management.proto",
}