[misc] Move shared components to shared directory (#4286)

Moved the following directories:

```
  - management/client → shared/management/client
  - management/domain → shared/management/domain
  - management/proto → shared/management/proto
  - signal/client → shared/signal/client
  - signal/proto → shared/signal/proto
  - relay/client → shared/relay/client
  - relay/auth → shared/relay/auth
```

and adjusted import paths
This commit is contained in:
Viktor Liu
2025-08-05 15:22:58 +02:00
committed by GitHub
parent 3d3c4c5844
commit 1d5e871bdf
172 changed files with 181 additions and 152 deletions

View File

@@ -0,0 +1,26 @@
package client
import (
"context"
"io"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
IsHealthy() bool
SyncMeta(sysInfo *system.Info) error
Logout() error
}

View File

@@ -0,0 +1,554 @@
package client
import (
"context"
"net"
"os"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
mgmt "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/util"
)
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
func TestMain(m *testing.M) {
_ = util.InitLog("debug", util.LogConsole)
code := m.Run()
os.Exit(code)
}
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
t.Helper()
level, _ := log.ParseLevel("debug")
log.SetLevel(level)
config := &types.Config{}
_, err := util.ReadJson("../server/testdata/management.json", config)
if err != nil {
t.Fatal(err)
}
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir())
if err != nil {
t.Fatal(err)
}
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.
EXPECT().
GetSettings(
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.
EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.
EXPECT().
ValidateUserPermissions(
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
gomock.Any(),
).
Return(true, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
if err != nil {
t.Fatal(err)
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
return
}
}()
return s, lis
}
func startMockManagement(t *testing.T) (*grpc.Server, net.Listener, *mock_server.ManagementServiceServerMock, wgtypes.Key) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
s := grpc.NewServer()
serverKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
mgmtMockServer := &mock_server.ManagementServiceServerMock{
GetServerKeyFunc: func(context.Context, *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) {
response := &mgmtProto.ServerKeyResponse{
Key: serverKey.PublicKey().String(),
}
return response, nil
},
}
mgmtProto.RegisterManagementServiceServer(s, mgmtMockServer)
go func() {
if err := s.Serve(lis); err != nil {
t.Error(err)
return
}
}()
return s, lis, mgmtMockServer, serverKey
}
func closeManagementSilently(s *grpc.Server, listener net.Listener) {
s.GracefulStop()
err := listener.Close()
if err != nil {
log.Warnf("error while closing management listener %v", err)
return
}
}
func TestClient_GetServerPublicKey(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error("couldn't retrieve management public key")
}
if key == nil {
t.Error("got an empty management public key")
}
}
func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Fatal(err)
}
sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo, nil, nil)
if err == nil {
t.Error("expecting err on unregistered login, got nil")
}
if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied {
t.Errorf("expecting err code %d denied on unregistered login got %d", codes.PermissionDenied, s.Code())
}
}
func TestClient_LoginRegistered(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
key, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
if resp == nil {
t.Error("expecting non nil response, got nil")
}
}
func TestClient_Sync(t *testing.T) {
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
s, listener := startManagement(t)
defer closeManagementSilently(s, listener)
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
if err != nil {
t.Fatal(err)
}
serverKey, err := client.GetServerPublicKey()
if err != nil {
t.Error(err)
}
info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
if err != nil {
t.Error(err)
}
// create and register second peer (we should receive on Sync request)
remoteKey, err := wgtypes.GenerateKey()
if err != nil {
t.Error(err)
}
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
if err != nil {
t.Fatal(err)
}
info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
if err != nil {
t.Fatal(err)
}
ch := make(chan *mgmtProto.SyncResponse, 1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go func() {
err = client.Sync(ctx, info, func(msg *mgmtProto.SyncResponse) error {
ch <- msg
return nil
})
if err != nil {
return
}
}()
select {
case resp := <-ch:
if resp.GetPeerConfig() == nil {
t.Error("expecting non nil PeerConfig got nil")
}
if resp.GetNetbirdConfig() == nil {
t.Error("expecting non nil NetbirdConfig got nil")
}
if len(resp.GetRemotePeers()) != 1 {
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
return
}
if resp.GetRemotePeersIsEmpty() == true {
t.Error("expecting RemotePeers property to be false, got true")
}
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
}
case <-time.After(3 * time.Second):
t.Error("timeout waiting for test to finish")
}
}
func Test_SystemMetaDataFromClient(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
testClient, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
key, err := testClient.GetServerPublicKey()
if err != nil {
t.Fatalf("error while getting server public key from testclient, %v", err)
}
var actualMeta *mgmtProto.PeerSystemMeta
var actualValidKey string
var wg sync.WaitGroup
wg.Add(1)
mgmtMockServer.LoginFunc = func(ctx context.Context, msg *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
if err != nil {
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
}
loginReq := &mgmtProto.LoginRequest{}
err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq)
if err != nil {
log.Fatal(err)
}
actualMeta = loginReq.GetMeta()
actualValidKey = loginReq.GetSetupKey()
wg.Done()
loginResp := &mgmtProto.LoginResponse{}
encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
if err != nil {
t.Errorf("error while trying to register client: %v", err)
}
wg.Wait()
protoNetAddr := make([]*mgmtProto.NetworkAddress, 0, len(info.NetworkAddresses))
for _, addr := range info.NetworkAddresses {
protoNetAddr = append(protoNetAddr, &mgmtProto.NetworkAddress{
NetIP: addr.NetIP.String(),
Mac: addr.Mac,
})
}
expectedMeta := &mgmtProto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
Kernel: info.Kernel,
Platform: info.Platform,
OS: info.OS,
Core: info.OSVersion,
OSVersion: info.OSVersion,
NetbirdVersion: info.NetbirdVersion,
KernelVersion: info.KernelVersion,
NetworkAddresses: protoNetAddr,
SysSerialNumber: info.SystemSerialNumber,
SysProductName: info.SystemProductName,
SysManufacturer: info.SystemManufacturer,
Environment: &mgmtProto.Environment{Cloud: info.Environment.Cloud, Platform: info.Environment.Platform},
}
assert.Equal(t, ValidKey, actualValidKey)
if !isEqual(expectedMeta, actualMeta) {
t.Errorf("expected and actual meta are not equal")
}
}
func isEqual(a, b *mgmtProto.PeerSystemMeta) bool {
if len(a.NetworkAddresses) != len(b.NetworkAddresses) {
return false
}
for _, addr := range a.GetNetworkAddresses() {
var found bool
for _, oAddr := range b.GetNetworkAddresses() {
if addr.GetMac() == oAddr.GetMac() && addr.GetNetIP() == oAddr.GetNetIP() {
found = true
continue
}
}
if !found {
return false
}
}
log.Infof("------")
return a.GetHostname() == b.GetHostname() &&
a.GetGoOS() == b.GetGoOS() &&
a.GetKernel() == b.GetKernel() &&
a.GetKernelVersion() == b.GetKernelVersion() &&
a.GetCore() == b.GetCore() &&
a.GetPlatform() == b.GetPlatform() &&
a.GetOS() == b.GetOS() &&
a.GetOSVersion() == b.GetOSVersion() &&
a.GetNetbirdVersion() == b.GetNetbirdVersion() &&
a.GetUiVersion() == b.GetUiVersion() &&
a.GetSysSerialNumber() == b.GetSysSerialNumber() &&
a.GetSysProductName() == b.GetSysProductName() &&
a.GetSysManufacturer() == b.GetSysManufacturer() &&
a.GetEnvironment().Cloud == b.GetEnvironment().Cloud &&
a.GetEnvironment().Platform == b.GetEnvironment().Platform
}
func Test_GetDeviceAuthorizationFlow(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
Provider: 0,
ProviderConfig: &mgmtProto.ProviderConfig{ClientID: "client"},
}
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving device auth flow information")
}
assert.Equal(t, expectedFlowInfo.Provider, flowInfo.Provider, "provider should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
}
func Test_GetPKCEAuthorizationFlow(t *testing.T) {
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
defer s.GracefulStop()
testKey, err := wgtypes.GenerateKey()
if err != nil {
t.Fatal(err)
}
serverAddr := lis.Addr().String()
ctx := context.Background()
client, err := NewClient(ctx, serverAddr, testKey, false)
if err != nil {
t.Fatalf("error while creating testClient: %v", err)
}
expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
ProviderConfig: &mgmtProto.ProviderConfig{
ClientID: "client",
ClientSecret: "secret",
},
}
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
if err != nil {
return nil, err
}
return &mgmtProto.EncryptedMessage{
WgPubKey: serverKey.PublicKey().String(),
Body: encryptedResp,
Version: 0,
}, nil
}
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
if err != nil {
t.Error("error while retrieving pkce auth flow information")
}
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
}

View File

@@ -0,0 +1,19 @@
package common
// LoginFlag introduces additional login flags to the PKCE authorization request
type LoginFlag uint8
const (
// LoginFlagPrompt adds prompt=login to the authorization request
LoginFlagPrompt LoginFlag = iota
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
LoginFlagMaxAge0
)
func (l LoginFlag) IsPromptLogin() bool {
return l == LoginFlagPrompt
}
func (l LoginFlag) IsMaxAge0Login() bool {
return l == LoginFlagMaxAge0
}

View File

@@ -0,0 +1,3 @@
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=

View File

@@ -0,0 +1,584 @@
package client
import (
"context"
"errors"
"fmt"
"io"
"sync"
"time"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
const ConnectTimeout = 10 * time.Second
const (
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
errMsgNoMgmtConnection = "no connection to management"
)
// ConnStateNotifier is a wrapper interface of the status recorders
type ConnStateNotifier interface {
MarkManagementDisconnected(error)
MarkManagementConnected()
}
type GrpcClient struct {
key wgtypes.Key
realClient proto.ManagementServiceClient
ctx context.Context
conn *grpc.ClientConn
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
}
// NewClient creates a new client to Management service
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed creating connection to Management Service: %v", err)
return nil, err
}
realClient := proto.NewManagementServiceClient(conn)
return &GrpcClient{
key: ourPrivateKey,
realClient: realClient,
ctx: ctx,
conn: conn,
connStateCallbackLock: sync.RWMutex{},
}, nil
}
// Close closes connection to the Management Service
func (c *GrpcClient) Close() error {
return c.conn.Close()
}
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
defer c.connStateCallbackLock.Unlock()
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// ready indicates whether the client is okay and ready to be used
// for now it just checks whether gRPC connection to the service is ready
func (c *GrpcClient) ready() bool {
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
}
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState()
if connState == connectivity.Shutdown {
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
c.conn.WaitForStateChange(ctx, connState)
return fmt.Errorf("connection to management is not ready and in %s state", connState)
}
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
}
err := backoff.Retry(operation, defaultBackoff(ctx))
if err != nil {
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
}
return err
}
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
msgHandler func(msg *proto.SyncResponse) error) error {
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
return err
}
log.Infof("connected to the Management Service stream")
c.notifyConnected()
// blocking until error
err = c.receiveEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
s, _ := gstatus.FromError(err)
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
}
return nil
}
// GetNetworkMap return with the network map
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf("failed getting Management Service public key: %s", err)
return nil, err
}
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
}
defer func() {
_ = stream.CloseSend()
}()
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return nil, err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return nil, err
}
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return nil, err
}
if decryptedResp.GetNetworkMap() == nil {
return nil, fmt.Errorf("invalid msg, required network map")
}
return decryptedResp.GetNetworkMap(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
myPrivateKey := c.key
myPublicKey := myPrivateKey.PublicKey()
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
if err != nil {
log.Errorf("failed encrypting message: %s", err)
return nil, err
}
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
sync, err := c.realClient.Sync(ctx, syncReq)
if err != nil {
return nil, err
}
return sync, nil
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
for {
update, err := stream.Recv()
if err == io.EOF {
log.Debugf("Management stream has been closed by server: %s", err)
return err
}
if err != nil {
log.Debugf("disconnected from Management Service sync stream: %v", err)
return err
}
log.Debugf("got an update message from Management Service")
decryptedResp := &proto.SyncResponse{}
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
if err != nil {
log.Errorf("failed decrypting update message from Management Service: %s", err)
return err
}
if err := msgHandler(decryptedResp); err != nil {
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
}
}
}
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
defer cancel()
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
return nil, fmt.Errorf("failed while getting Management Service public key")
}
serverKey, err := wgtypes.ParseKey(resp.Key)
if err != nil {
return nil, err
}
return &serverKey, nil
}
// IsHealthy probes the gRPC connection and returns false on errors
func (c *GrpcClient) IsHealthy() bool {
switch c.conn.GetState() {
case connectivity.TransientFailure:
return false
case connectivity.Connecting:
return true
case connectivity.Shutdown:
return true
case connectivity.Idle:
case connectivity.Ready:
}
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
defer cancel()
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
if err != nil {
c.notifyDisconnected(err)
log.Warnf("health check returned: %s", err)
return false
}
c.notifyConnected()
return true
}
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
if !c.ready() {
return nil, errors.New(errMsgNoMgmtConnection)
}
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return nil, err
}
var resp *proto.EncryptedMessage
operation := func() error {
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
defer cancel()
var err error
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: loginReq,
})
if err != nil {
// retry only on context canceled
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
return backoff.Permanent(err)
}
return nil
}
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
if err != nil {
log.Errorf("failed to login to Management Service: %v", err)
return nil, err
}
loginResp := &proto.LoginResponse{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
if err != nil {
log.Errorf("failed to decrypt login response: %s", err)
return nil, err
}
return loginResp, nil
}
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
}
// GetDeviceAuthorizationFlow returns a device authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.DeviceAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG},
)
if err != nil {
return nil, err
}
flowInfoResp := &proto.DeviceAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
return flowInfoResp, nil
}
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
// It also takes care of encrypting and decrypting messages.
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
if !c.ready() {
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
defer cancel()
message := &proto.PKCEAuthorizationFlowRequest{}
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
if err != nil {
return nil, err
}
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
})
if err != nil {
return nil, err
}
flowInfoResp := &proto.PKCEAuthorizationFlow{}
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
if err != nil {
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
log.Error(errWithMSG)
return nil, errWithMSG
}
return flowInfoResp, nil
}
// SyncMeta sends updated system metadata to the Management Service.
// It should be used if there is changes on peer posture check after initial sync.
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
if !c.ready() {
return errors.New(errMsgNoMgmtConnection)
}
serverPubKey, err := c.GetServerPublicKey()
if err != nil {
log.Debugf(errMsgMgmtPublicKey, err)
return err
}
syncMetaReq, err := encryption.EncryptMessage(*serverPubKey, c.key, &proto.SyncMetaRequest{Meta: infoToMetaData(sysInfo)})
if err != nil {
log.Errorf("failed to encrypt message: %s", err)
return err
}
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
defer cancel()
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: syncMetaReq,
})
return err
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementDisconnected(err)
}
func (c *GrpcClient) notifyConnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkManagementConnected()
}
func (c *GrpcClient) Logout() error {
serverKey, err := c.GetServerPublicKey()
if err != nil {
return fmt.Errorf("get server public key: %w", err)
}
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*5)
defer cancel()
message := &proto.Empty{}
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
if err != nil {
return fmt.Errorf("encrypt logout message: %w", err)
}
_, err = c.realClient.Logout(mgmCtx, &proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encryptedMSG,
})
if err != nil {
return fmt.Errorf("logout: %w", err)
}
return nil
}
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
if info == nil {
return nil
}
addresses := make([]*proto.NetworkAddress, 0, len(info.NetworkAddresses))
for _, addr := range info.NetworkAddresses {
addresses = append(addresses, &proto.NetworkAddress{
NetIP: addr.NetIP.String(),
Mac: addr.Mac,
})
}
files := make([]*proto.File, 0, len(info.Files))
for _, file := range info.Files {
files = append(files, &proto.File{
Path: file.Path,
Exist: file.Exist,
ProcessIsRunning: file.ProcessIsRunning,
})
}
return &proto.PeerSystemMeta{
Hostname: info.Hostname,
GoOS: info.GoOS,
OS: info.OS,
Core: info.OSVersion,
OSVersion: info.OSVersion,
Platform: info.Platform,
Kernel: info.Kernel,
NetbirdVersion: info.NetbirdVersion,
UiVersion: info.UIVersion,
KernelVersion: info.KernelVersion,
NetworkAddresses: addresses,
SysSerialNumber: info.SystemSerialNumber,
SysManufacturer: info.SystemManufacturer,
SysProductName: info.SystemProductName,
Environment: &proto.Environment{
Cloud: info.Environment.Cloud,
Platform: info.Environment.Platform,
},
Files: files,
Flags: &proto.Flags{
RosenpassEnabled: info.RosenpassEnabled,
RosenpassPermissive: info.RosenpassPermissive,
ServerSSHAllowed: info.ServerSSHAllowed,
DisableClientRoutes: info.DisableClientRoutes,
DisableServerRoutes: info.DisableServerRoutes,
DisableDNS: info.DisableDNS,
DisableFirewall: info.DisableFirewall,
BlockLANAccess: info.BlockLANAccess,
BlockInbound: info.BlockInbound,
LazyConnectionEnabled: info.LazyConnectionEnabled,
},
}
}

View File

@@ -0,0 +1,95 @@
package client
import (
"context"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
)
type MockClient struct {
CloseFunc func() error
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
}
func (m *MockClient) IsHealthy() bool {
return true
}
func (m *MockClient) Close() error {
if m.CloseFunc == nil {
return nil
}
return m.CloseFunc()
}
func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
if m.SyncFunc == nil {
return nil
}
return m.SyncFunc(ctx, sysInfo, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil
}
return m.GetServerPublicKeyFunc()
}
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil {
return nil, nil
}
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
}
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
if m.LoginFunc == nil {
return nil, nil
}
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
}
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
if m.GetDeviceAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetDeviceAuthorizationFlowFunc(serverKey)
}
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
if m.GetPKCEAuthorizationFlowFunc == nil {
return nil, nil
}
return m.GetPKCEAuthorizationFlow(serverKey)
}
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
return nil, nil
}
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
if m.SyncMetaFunc == nil {
return nil
}
return m.SyncMetaFunc(sysInfo)
}
func (m *MockClient) Logout() error {
if m.LogoutFunc == nil {
return nil
}
return m.LogoutFunc()
}

View File

@@ -0,0 +1,60 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// AccountsAPI APIs for accounts, do not use directly
type AccountsAPI struct {
c *Client
}
// List list all accounts, only returns one account always
// See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Account](resp)
return ret, err
}
// Update update account settings
// See more: https://docs.netbird.io/api/resources/accounts#update-an-account
func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.PutApiAccountsAccountIdJSONRequestBody) (*api.Account, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Account](resp)
return &ret, err
}
// Delete delete account
// See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,183 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testAccount = api.Account{
Id: "Test",
Settings: api.AccountSettings{
Extra: &api.AccountExtraSettings{
PeerApprovalEnabled: false,
},
GroupsPropagationEnabled: ptr(true),
JwtGroupsEnabled: ptr(false),
PeerInactivityExpiration: 7,
PeerInactivityExpirationEnabled: true,
PeerLoginExpiration: 24,
PeerLoginExpirationEnabled: true,
RegularUsersViewBlocked: false,
RoutingPeerDnsResolutionEnabled: ptr(false),
},
}
)
func TestAccounts_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Account{testAccount})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Accounts.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testAccount, ret[0])
})
}
func TestAccounts_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Accounts.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestAccounts_List_ConnErr(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
ret, err := c.Accounts.List(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "404")
assert.Empty(t, ret)
})
}
func TestAccounts_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiAccountsAccountIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, *req.Settings.RoutingPeerDnsResolutionEnabled)
retBytes, _ := json.Marshal(testAccount)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Accounts.Update(context.Background(), "Test", api.PutApiAccountsAccountIdJSONRequestBody{
Settings: api.AccountSettings{
RoutingPeerDnsResolutionEnabled: ptr(true),
},
})
require.NoError(t, err)
assert.Equal(t, testAccount, *ret)
})
}
func TestAccounts_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Accounts.Update(context.Background(), "Test", api.PutApiAccountsAccountIdJSONRequestBody{
Settings: api.AccountSettings{
RoutingPeerDnsResolutionEnabled: ptr(true),
},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestAccounts_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Accounts.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestAccounts_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Accounts.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestAccounts_Integration_List(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
accounts, err := c.Accounts.List(context.Background())
require.NoError(t, err)
assert.Len(t, accounts, 1)
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", accounts[0].Id)
assert.Equal(t, false, accounts[0].Settings.Extra.PeerApprovalEnabled)
})
}
func TestAccounts_Integration_Update(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
accounts, err := c.Accounts.List(context.Background())
require.NoError(t, err)
assert.Len(t, accounts, 1)
accounts[0].Settings.JwtAllowGroups = ptr([]string{"test"})
account, err := c.Accounts.Update(context.Background(), accounts[0].Id, api.AccountRequest{
Settings: accounts[0].Settings,
})
require.NoError(t, err)
assert.Equal(t, accounts[0].Id, account.Id)
assert.Equal(t, []string{"test"}, *account.Settings.JwtAllowGroups)
})
}
// Account deletion on MySQL and PostgreSQL databases causes unknown errors
// func TestAccounts_Integration_Delete(t *testing.T) {
// withBlackBoxServer(t, func(c *rest.Client) {
// accounts, err := c.Accounts.List(context.Background())
// require.NoError(t, err)
// assert.Len(t, accounts, 1)
// err = c.Accounts.Delete(context.Background(), accounts[0].Id)
// require.NoError(t, err)
// _, err = c.Accounts.List(context.Background())
// assert.Error(t, err)
// })
// }

View File

@@ -0,0 +1,172 @@
package rest
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"github.com/netbirdio/netbird/management/server/http/util"
)
// Client Management service HTTP REST API Client
type Client struct {
managementURL string
authHeader string
httpClient HttpClient
// Accounts NetBird account APIs
// see more: https://docs.netbird.io/api/resources/accounts
Accounts *AccountsAPI
// Users NetBird users APIs
// see more: https://docs.netbird.io/api/resources/users
Users *UsersAPI
// Tokens NetBird tokens APIs
// see more: https://docs.netbird.io/api/resources/tokens
Tokens *TokensAPI
// Peers NetBird peers APIs
// see more: https://docs.netbird.io/api/resources/peers
Peers *PeersAPI
// SetupKeys NetBird setup keys APIs
// see more: https://docs.netbird.io/api/resources/setup-keys
SetupKeys *SetupKeysAPI
// Groups NetBird groups APIs
// see more: https://docs.netbird.io/api/resources/groups
Groups *GroupsAPI
// Policies NetBird policies APIs
// see more: https://docs.netbird.io/api/resources/policies
Policies *PoliciesAPI
// PostureChecks NetBird posture checks APIs
// see more: https://docs.netbird.io/api/resources/posture-checks
PostureChecks *PostureChecksAPI
// Networks NetBird networks APIs
// see more: https://docs.netbird.io/api/resources/networks
Networks *NetworksAPI
// Routes NetBird routes APIs
// see more: https://docs.netbird.io/api/resources/routes
Routes *RoutesAPI
// DNS NetBird DNS APIs
// see more: https://docs.netbird.io/api/resources/routes
DNS *DNSAPI
// GeoLocation NetBird Geo Location APIs
// see more: https://docs.netbird.io/api/resources/geo-locations
GeoLocation *GeoLocationAPI
// Events NetBird Events APIs
// see more: https://docs.netbird.io/api/resources/events
Events *EventsAPI
}
// New initialize new Client instance using PAT token
func New(managementURL, token string) *Client {
return NewWithOptions(
WithManagementURL(managementURL),
WithPAT(token),
)
}
// NewWithBearerToken initialize new Client instance using Bearer token type
func NewWithBearerToken(managementURL, token string) *Client {
return NewWithOptions(
WithManagementURL(managementURL),
WithBearerToken(token),
)
}
// NewWithOptions initialize new Client instance with options
func NewWithOptions(opts ...option) *Client {
client := &Client{
httpClient: http.DefaultClient,
}
for _, option := range opts {
option(client)
}
client.initialize()
return client
}
func (c *Client) initialize() {
c.Accounts = &AccountsAPI{c}
c.Users = &UsersAPI{c}
c.Tokens = &TokensAPI{c}
c.Peers = &PeersAPI{c}
c.SetupKeys = &SetupKeysAPI{c}
c.Groups = &GroupsAPI{c}
c.Policies = &PoliciesAPI{c}
c.PostureChecks = &PostureChecksAPI{c}
c.Networks = &NetworksAPI{c}
c.Routes = &RoutesAPI{c}
c.DNS = &DNSAPI{c}
c.GeoLocation = &GeoLocationAPI{c}
c.Events = &EventsAPI{c}
}
// NewRequest creates and executes new management API request
func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader, query map[string]string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
if err != nil {
return nil, err
}
req.Header.Add("Authorization", c.authHeader)
req.Header.Add("Accept", "application/json")
if body != nil {
req.Header.Add("Content-Type", "application/json")
}
if len(query) != 0 {
q := req.URL.Query()
for k, v := range query {
q.Add(k, v)
}
req.URL.RawQuery = q.Encode()
}
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode > 299 {
parsedErr, pErr := parseResponse[util.ErrorResponse](resp)
if pErr != nil {
return nil, pErr
}
return nil, errors.New(parsedErr.Message)
}
return resp, nil
}
func parseResponse[T any](resp *http.Response) (T, error) {
var ret T
if resp.Body == nil {
return ret, fmt.Errorf("Body missing, HTTP Error code %d", resp.StatusCode)
}
bs, err := io.ReadAll(resp.Body)
if err != nil {
return ret, err
}
err = json.Unmarshal(bs, &ret)
if err != nil {
return ret, fmt.Errorf("Error code %d, error unmarshalling body: %w", resp.StatusCode, err)
}
return ret, nil
}

View File

@@ -0,0 +1,34 @@
//go:build integration
// +build integration
package rest_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)
func withMockClient(callback func(*rest.Client, *http.ServeMux)) {
mux := &http.ServeMux{}
server := httptest.NewServer(mux)
defer server.Close()
c := rest.New(server.URL, "ABC")
callback(c, mux)
}
func ptr[T any, PT *T](x T) PT {
return &x
}
func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) {
t.Helper()
handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../server/testdata/store.sql", nil, false)
server := httptest.NewServer(handler)
defer server.Close()
c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp")
callback(c)
}

View File

@@ -0,0 +1,124 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// DNSAPI APIs for DNS Management, do not use directly
type DNSAPI struct {
c *Client
}
// ListNameserverGroups list all nameserver groups
// See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.NameserverGroup](resp)
return ret, err
}
// GetNameserverGroup get nameserver group info
// See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
// CreateNameserverGroup create new nameserver group
// See more: https://docs.netbird.io/api/resources/dns#create-a-nameserver-group
func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiDnsNameserversJSONRequestBody) (*api.NameserverGroup, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
// UpdateNameserverGroup update nameserver group info
// See more: https://docs.netbird.io/api/resources/dns#update-a-nameserver-group
func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID string, request api.PutApiDnsNameserversNsgroupIdJSONRequestBody) (*api.NameserverGroup, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NameserverGroup](resp)
return &ret, err
}
// DeleteNameserverGroup delete nameserver group
// See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// GetSettings get DNS settings
// See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSSettings](resp)
return &ret, err
}
// UpdateSettings update DNS settings
// See more: https://docs.netbird.io/api/resources/dns#update-dns-settings
func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettingsJSONRequestBody) (*api.DNSSettings, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSSettings](resp)
return &ret, err
}

View File

@@ -0,0 +1,301 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testNameserverGroup = api.NameserverGroup{
Id: "Test",
Name: "wow",
}
testSettings = api.DNSSettings{
DisabledManagementGroups: []string{"gone"},
}
)
func TestDNSNameserverGroup_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NameserverGroup{testNameserverGroup})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.ListNameserverGroups(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNameserverGroup, ret[0])
})
}
func TestDNSNameserverGroup_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.ListNameserverGroups(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSNameserverGroup_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNameserverGroup)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.GetNameserverGroup(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testNameserverGroup, *ret)
})
}
func TestDNSNameserverGroup_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.GetNameserverGroup(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSNameserverGroup_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiDnsNameserversJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNameserverGroup)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.CreateNameserverGroup(context.Background(), api.PostApiDnsNameserversJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNameserverGroup, *ret)
})
}
func TestDNSNameserverGroup_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.CreateNameserverGroup(context.Background(), api.PostApiDnsNameserversJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSNameserverGroup_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNameserverGroup)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.UpdateNameserverGroup(context.Background(), "Test", api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNameserverGroup, *ret)
})
}
func TestDNSNameserverGroup_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.UpdateNameserverGroup(context.Background(), "Test", api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSNameserverGroup_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.DNS.DeleteNameserverGroup(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestDNSNameserverGroup_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.DNS.DeleteNameserverGroup(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestDNSSettings_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testSettings)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.GetSettings(context.Background())
require.NoError(t, err)
assert.Equal(t, testSettings, *ret)
})
}
func TestDNSSettings_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.GetSettings(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSSettings_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiDnsSettingsJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, []string{"test"}, req.DisabledManagementGroups)
retBytes, _ := json.Marshal(testSettings)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.UpdateSettings(context.Background(), api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{"test"},
})
require.NoError(t, err)
assert.Equal(t, testSettings, *ret)
})
}
func TestDNSSettings_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNS.UpdateSettings(context.Background(), api.PutApiDnsSettingsJSONRequestBody{
DisabledManagementGroups: []string{"test"},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestDNS_Integration(t *testing.T) {
nsGroupReq := api.NameserverGroupRequest{
Description: "Test",
Enabled: true,
Domains: []string{},
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
Name: "test",
Nameservers: []api.Nameserver{
{
Ip: "8.8.8.8",
NsType: api.NameserverNsTypeUdp,
Port: 53,
},
},
Primary: true,
SearchDomainsEnabled: false,
}
withBlackBoxServer(t, func(c *rest.Client) {
// Create
nsGroup, err := c.DNS.CreateNameserverGroup(context.Background(), nsGroupReq)
require.NoError(t, err)
// List
nsGroups, err := c.DNS.ListNameserverGroups(context.Background())
require.NoError(t, err)
assert.Equal(t, *nsGroup, nsGroups[0])
// Update
nsGroupReq.Description = "TestUpdate"
nsGroup, err = c.DNS.UpdateNameserverGroup(context.Background(), nsGroup.Id, nsGroupReq)
require.NoError(t, err)
assert.Equal(t, "TestUpdate", nsGroup.Description)
// Delete
err = c.DNS.DeleteNameserverGroup(context.Background(), nsGroup.Id)
require.NoError(t, err)
// List again to ensure deletion
nsGroups, err = c.DNS.ListNameserverGroups(context.Background())
require.NoError(t, err)
assert.Len(t, nsGroups, 0)
})
}

View File

@@ -0,0 +1,26 @@
package rest
import (
"context"
"github.com/netbirdio/netbird/management/server/http/api"
)
// EventsAPI APIs for Events, do not use directly
type EventsAPI struct {
c *Client
}
// List list all events
// See more: https://docs.netbird.io/api/resources/events#list-all-events
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Event](resp)
return ret, err
}

View File

@@ -0,0 +1,70 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testEvent = api.Event{
Activity: "AccountCreate",
ActivityCode: api.EventActivityCodeAccountCreate,
}
)
func TestEvents_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Event{testEvent})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testEvent, ret[0])
})
}
func TestEvents_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Events.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestEvents_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
// Do something that would trigger any event
_, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
Ephemeral: ptr(true),
Name: "TestSetupKey",
Type: "reusable",
})
require.NoError(t, err)
events, err := c.Events.List(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, events)
})
}

View File

@@ -0,0 +1,40 @@
package rest
import (
"context"
"github.com/netbirdio/netbird/management/server/http/api"
)
// GeoLocationAPI APIs for Geo-Location, do not use directly
type GeoLocationAPI struct {
c *Client
}
// ListCountries list all country codes
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Country](resp)
return ret, err
}
// ListCountryCities Get a list of all English city names for a given country code
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.City](resp)
return ret, err
}

View File

@@ -0,0 +1,101 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testCountry = api.Country{
CountryCode: "DE",
CountryName: "Germany",
}
testCity = api.City{
CityName: "Berlin",
GeonameId: 2950158,
}
)
func TestGeo_ListCountries_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Country{testCountry})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GeoLocation.ListCountries(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testCountry, ret[0])
})
}
func TestGeo_ListCountries_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GeoLocation.ListCountries(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGeo_ListCountryCities_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.City{testCity})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GeoLocation.ListCountryCities(context.Background(), "Test")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testCity, ret[0])
})
}
func TestGeo_ListCountryCities_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.GeoLocation.ListCountryCities(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGeo_Integration(t *testing.T) {
// Blackbox is initialized with empty GeoLocations
withBlackBoxServer(t, func(c *rest.Client) {
countries, err := c.GeoLocation.ListCountries(context.Background())
require.NoError(t, err)
assert.Empty(t, countries)
cities, err := c.GeoLocation.ListCountryCities(context.Background(), "DE")
require.NoError(t, err)
assert.Empty(t, cities)
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// GroupsAPI APIs for Groups, do not use directly
type GroupsAPI struct {
c *Client
}
// List list all groups
// See more: https://docs.netbird.io/api/resources/groups#list-all-groups
func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Group](resp)
return ret, err
}
// Get get group info
// See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group
func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
// Create create new group
// See more: https://docs.netbird.io/api/resources/groups#create-a-group
func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONRequestBody) (*api.Group, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
// Update update group info
// See more: https://docs.netbird.io/api/resources/groups#update-a-group
func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutApiGroupsGroupIdJSONRequestBody) (*api.Group, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Group](resp)
return &ret, err
}
// Delete delete group
// See more: https://docs.netbird.io/api/resources/groups#delete-a-group
func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,215 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testGroup = api.Group{
Id: "Test",
Name: "wow",
PeersCount: 0,
}
)
func TestGroups_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Group{testGroup})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testGroup, ret[0])
})
}
func TestGroups_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGroups_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testGroup)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testGroup, *ret)
})
}
func TestGroups_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestGroups_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiGroupsJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testGroup)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Create(context.Background(), api.PostApiGroupsJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testGroup, *ret)
})
}
func TestGroups_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Create(context.Background(), api.PostApiGroupsJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestGroups_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiGroupsGroupIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testGroup)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Update(context.Background(), "Test", api.PutApiGroupsGroupIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testGroup, *ret)
})
}
func TestGroups_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Groups.Update(context.Background(), "Test", api.PutApiGroupsGroupIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestGroups_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Groups.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestGroups_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Groups.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestGroups_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
groups, err := c.Groups.List(context.Background())
require.NoError(t, err)
assert.Len(t, groups, 1)
group, err := c.Groups.Create(context.Background(), api.GroupRequest{
Name: "Test",
})
require.NoError(t, err)
assert.Equal(t, "Test", group.Name)
assert.NotEmpty(t, group.Id)
group, err = c.Groups.Update(context.Background(), group.Id, api.GroupRequest{
Name: "Testnt",
})
require.NoError(t, err)
assert.Equal(t, "Testnt", group.Name)
err = c.Groups.Delete(context.Background(), group.Id)
require.NoError(t, err)
groups, err = c.Groups.List(context.Background())
require.NoError(t, err)
assert.Len(t, groups, 1)
})
}

View File

@@ -0,0 +1,48 @@
package rest
import (
"net/http"
"net/url"
)
// Impersonate returns a Client impersonated for a specific account
func (c *Client) Impersonate(account string) *Client {
client := NewWithOptions(
WithManagementURL(c.managementURL),
WithAuthHeader(c.authHeader),
WithHttpClient(newImpersonatedHttpClient(c, account)),
)
return client
}
type impersonatedHttpClient struct {
baseClient HttpClient
account string
}
func newImpersonatedHttpClient(c *Client, account string) *impersonatedHttpClient {
if hc, ok := c.httpClient.(*impersonatedHttpClient); ok {
hc.account = account
return hc
}
return &impersonatedHttpClient{
baseClient: c.httpClient,
account: account,
}
}
func (c *impersonatedHttpClient) Do(req *http.Request) (*http.Response, error) {
parsedURL, err := url.Parse(req.URL.String())
if err != nil {
return nil, err
}
query := parsedURL.Query()
query.Set("account", c.account)
parsedURL.RawQuery = query.Encode()
req.URL = parsedURL
return c.baseClient.Do(req)
}

View File

@@ -0,0 +1,77 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
)
var (
testImpersonatedAccount = api.Account{
Id: "ImpersonatedTest",
Settings: api.AccountSettings{
Extra: &api.AccountExtraSettings{
PeerApprovalEnabled: false,
},
GroupsPropagationEnabled: ptr(true),
JwtGroupsEnabled: ptr(false),
PeerInactivityExpiration: 7,
PeerInactivityExpirationEnabled: true,
PeerLoginExpiration: 24,
PeerLoginExpirationEnabled: true,
RegularUsersViewBlocked: false,
RoutingPeerDnsResolutionEnabled: ptr(false),
},
}
)
func TestImpersonation_Peers_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
impersonatedClient := c.Impersonate(testImpersonatedAccount.Id)
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, r.URL.Query().Get("account"), testImpersonatedAccount.Id)
retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := impersonatedClient.Peers.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPeer, ret[0])
})
}
func TestImpersonation_Change_Account(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
impersonatedClient := c.Impersonate(testImpersonatedAccount.Id)
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, r.URL.Query().Get("account"), testImpersonatedAccount.Id)
retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
_, err := impersonatedClient.Peers.List(context.Background())
require.NoError(t, err)
impersonatedClient = impersonatedClient.Impersonate("another-test-account")
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, r.URL.Query().Get("account"), "another-test-account")
retBytes, _ := json.Marshal(testPeer)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
_, err = impersonatedClient.Peers.Get(context.Background(), "Test")
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,276 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// NetworksAPI APIs for Networks, do not use directly
type NetworksAPI struct {
c *Client
}
// List list all networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-networks
func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Network](resp)
return ret, err
}
// Get get network info
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network
func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
// Create create new network
// See more: https://docs.netbird.io/api/resources/networks#create-a-network
func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSONRequestBody) (*api.Network, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
// Update update network
// See more: https://docs.netbird.io/api/resources/networks#update-a-network
func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.PutApiNetworksNetworkIdJSONRequestBody) (*api.Network, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Network](resp)
return &ret, err
}
// Delete delete network
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network
func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// NetworkResourcesAPI APIs for Network Resources, do not use directly
type NetworkResourcesAPI struct {
c *Client
networkID string
}
// Resources APIs for network resources
func (a *NetworksAPI) Resources(networkID string) *NetworkResourcesAPI {
return &NetworkResourcesAPI{
c: a.c,
networkID: networkID,
}
}
// List list all resources in networks
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-resources
func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.NetworkResource](resp)
return ret, err
}
// Get get network resource info
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network-resource
func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) (*api.NetworkResource, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
// Create create new network resource
// See more: https://docs.netbird.io/api/resources/networks#create-a-network-resource
func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNetworksNetworkIdResourcesJSONRequestBody) (*api.NetworkResource, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
// Update update network resource
// See more: https://docs.netbird.io/api/resources/networks#update-a-network-resource
func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID string, request api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody) (*api.NetworkResource, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkResource](resp)
return &ret, err
}
// Delete delete network resource
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network-resource
func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// NetworkRoutersAPI APIs for Network Routers, do not use directly
type NetworkRoutersAPI struct {
c *Client
networkID string
}
// Routers APIs for network routers
func (a *NetworksAPI) Routers(networkID string) *NetworkRoutersAPI {
return &NetworkRoutersAPI{
c: a.c,
networkID: networkID,
}
}
// List list all routers in networks
// See more: https://docs.netbird.io/api/routers/networks#list-all-network-routers
func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.NetworkRouter](resp)
return ret, err
}
// Get get network router info
// See more: https://docs.netbird.io/api/routers/networks#retrieve-a-network-router
func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*api.NetworkRouter, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
// Create create new network router
// See more: https://docs.netbird.io/api/routers/networks#create-a-network-router
func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetworksNetworkIdRoutersJSONRequestBody) (*api.NetworkRouter, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
// Update update network router
// See more: https://docs.netbird.io/api/routers/networks#update-a-network-router
func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string, request api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody) (*api.NetworkRouter, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.NetworkRouter](resp)
return &ret, err
}
// Delete delete network router
// See more: https://docs.netbird.io/api/routers/networks#delete-a-network-router
func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,594 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testNetwork = api.Network{
Id: "Test",
Name: "wow",
}
testNetworkResource = api.NetworkResource{
Description: ptr("meaw"),
Id: "awa",
}
testNetworkRouter = api.NetworkRouter{
Id: "ouch",
}
)
func TestNetworks_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Network{testNetwork})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNetwork, ret[0])
})
}
func TestNetworks_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworks_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetwork)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testNetwork, *ret)
})
}
func TestNetworks_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworks_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiNetworksJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNetwork)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Create(context.Background(), api.PostApiNetworksJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNetwork, *ret)
})
}
func TestNetworks_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Create(context.Background(), api.PostApiNetworksJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworks_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiNetworksNetworkIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNetwork)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Update(context.Background(), "Test", api.PutApiNetworksNetworkIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNetwork, *ret)
})
}
func TestNetworks_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Update(context.Background(), "Test", api.PutApiNetworksNetworkIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworks_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Networks.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestNetworks_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Networks.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestNetworks_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
network, err := c.Networks.Create(context.Background(), api.NetworkRequest{
Description: ptr("TestNetwork"),
Name: "Test",
})
assert.NoError(t, err)
assert.Equal(t, "Test", network.Name)
networks, err := c.Networks.List(context.Background())
assert.NoError(t, err)
assert.Empty(t, networks)
network, err = c.Networks.Update(context.Background(), "TestID", api.NetworkRequest{
Description: ptr("TestNetwork?"),
Name: "Test",
})
assert.NoError(t, err)
assert.Equal(t, "TestNetwork?", *network.Description)
err = c.Networks.Delete(context.Background(), "TestID")
assert.NoError(t, err)
})
}
func TestNetworkResources_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NetworkResource{testNetworkResource})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNetworkResource, ret[0])
})
}
func TestNetworkResources_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkResources_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetworkResource)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testNetworkResource, *ret)
})
}
func TestNetworkResources_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkResources_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiNetworksNetworkIdResourcesJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNetworkResource)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdResourcesJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNetworkResource, *ret)
})
}
func TestNetworkResources_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdResourcesJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworkResources_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testNetworkResource)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testNetworkResource, *ret)
})
}
func TestNetworkResources_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Resources("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworkResources_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Networks.Resources("Meow").Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestNetworkResources_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Networks.Resources("Meow").Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestNetworkResources_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
_, err := c.Networks.Resources("TestNetwork").Create(context.Background(), api.NetworkResourceRequest{
Address: "test.com",
Description: ptr("Description"),
Enabled: false,
Groups: []string{"test"},
Name: "test",
})
assert.NoError(t, err)
_, err = c.Networks.Resources("TestNetwork").List(context.Background())
assert.NoError(t, err)
_, err = c.Networks.Resources("TestNetwork").Get(context.Background(), "TestResource")
assert.NoError(t, err)
_, err = c.Networks.Resources("TestNetwork").Update(context.Background(), "TestResource", api.NetworkResourceRequest{
Address: "testnt.com",
})
assert.NoError(t, err)
err = c.Networks.Resources("TestNetwork").Delete(context.Background(), "TestResource")
assert.NoError(t, err)
})
}
func TestNetworkRouters_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testNetworkRouter, ret[0])
})
}
func TestNetworkRouters_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkRouters_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testNetworkRouter)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testNetworkRouter, *ret)
})
}
func TestNetworkRouters_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestNetworkRouters_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiNetworksNetworkIdRoutersJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "test", *req.Peer)
retBytes, _ := json.Marshal(testNetworkRouter)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdRoutersJSONRequestBody{
Peer: ptr("test"),
})
require.NoError(t, err)
assert.Equal(t, testNetworkRouter, *ret)
})
}
func TestNetworkRouters_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdRoutersJSONRequestBody{
Peer: ptr("test"),
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworkRouters_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "test", *req.Peer)
retBytes, _ := json.Marshal(testNetworkRouter)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody{
Peer: ptr("test"),
})
require.NoError(t, err)
assert.Equal(t, testNetworkRouter, *ret)
})
}
func TestNetworkRouters_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Networks.Routers("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody{
Peer: ptr("test"),
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestNetworkRouters_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Networks.Routers("Meow").Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestNetworkRouters_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Networks.Routers("Meow").Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestNetworkRouters_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
_, err := c.Networks.Routers("TestNetwork").Create(context.Background(), api.NetworkRouterRequest{
Enabled: false,
Masquerade: false,
Metric: 9999,
PeerGroups: ptr([]string{"test"}),
})
assert.NoError(t, err)
_, err = c.Networks.Routers("TestNetwork").List(context.Background())
assert.NoError(t, err)
_, err = c.Networks.Routers("TestNetwork").Get(context.Background(), "TestRouter")
assert.NoError(t, err)
_, err = c.Networks.Routers("TestNetwork").Update(context.Background(), "TestRouter", api.NetworkRouterRequest{
Enabled: true,
})
assert.NoError(t, err)
err = c.Networks.Routers("TestNetwork").Delete(context.Background(), "TestRouter")
assert.NoError(t, err)
})
}

View File

@@ -0,0 +1,44 @@
package rest
import "net/http"
// option modifier for creation of Client
type option func(*Client)
// HTTPClient interface for HTTP client
type HttpClient interface {
Do(req *http.Request) (*http.Response, error)
}
// WithHTTPClient overrides HTTPClient used
func WithHttpClient(client HttpClient) option {
return func(c *Client) {
c.httpClient = client
}
}
// WithBearerToken uses provided bearer token acquired from SSO for authentication
func WithBearerToken(token string) option {
return WithAuthHeader("Bearer " + token)
}
// WithPAT uses provided Personal Access Token
// (created from NetBird Management Dashboard) for authentication
func WithPAT(token string) option {
return WithAuthHeader("Token " + token)
}
// WithManagementURL overrides target NetBird Management server
func WithManagementURL(url string) option {
return func(c *Client) {
c.managementURL = url
}
}
// WithAuthHeader overrides auth header completely, this should generally not be used
// and WithBearerToken or WithPAT should be used instead
func WithAuthHeader(value string) option {
return func(c *Client) {
c.authHeader = value
}
}

View File

@@ -0,0 +1,108 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// PeersAPI APIs for peers, do not use directly
type PeersAPI struct {
c *Client
}
// PeersListOption options for Peers List API
type PeersListOption func() (string, string)
func PeerNameFilter(name string) PeersListOption {
return func() (string, string) {
return "name", name
}
}
func PeerIPFilter(ip string) PeersListOption {
return func() (string, string) {
return "ip", ip
}
}
// List list all peers
// See more: https://docs.netbird.io/api/resources/peers#list-all-peers
func (a *PeersAPI) List(ctx context.Context, opts ...PeersListOption) ([]api.Peer, error) {
query := make(map[string]string)
for _, o := range opts {
k, v := o()
query[k] = v
}
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil, query)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Peer](resp)
return ret, err
}
// Get retrieve a peer
// See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer
func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Peer](resp)
return &ret, err
}
// Update update information for a peer
// See more: https://docs.netbird.io/api/resources/peers#update-a-peer
func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApiPeersPeerIdJSONRequestBody) (*api.Peer, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Peer](resp)
return &ret, err
}
// Delete delete a peer
// See more: https://docs.netbird.io/api/resources/peers#delete-a-peer
func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// ListAccessiblePeers list all peers that the specified peer can connect to within the network
// See more: https://docs.netbird.io/api/resources/peers#list-accessible-peers
func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]api.Peer, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Peer](resp)
return ret, err
}

View File

@@ -0,0 +1,212 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testPeer = api.Peer{
ApprovalRequired: false,
Connected: false,
ConnectionIp: "127.0.0.1",
DnsLabel: "test",
Id: "Test",
}
)
func TestPeers_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPeer, ret[0])
})
}
func TestPeers_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeers_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPeer)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testPeer, *ret)
})
}
func TestPeers_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeers_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiPeersPeerIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, req.InactivityExpirationEnabled)
retBytes, _ := json.Marshal(testPeer)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Update(context.Background(), "Test", api.PutApiPeersPeerIdJSONRequestBody{
InactivityExpirationEnabled: true,
})
require.NoError(t, err)
assert.Equal(t, testPeer, *ret)
})
}
func TestPeers_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.Update(context.Background(), "Test", api.PutApiPeersPeerIdJSONRequestBody{
InactivityExpirationEnabled: false,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPeers_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Peers.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestPeers_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Peers.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestPeers_ListAccessiblePeers_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Peer{testPeer})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.ListAccessiblePeers(context.Background(), "Test")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPeer, ret[0])
})
}
func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Peers.ListAccessiblePeers(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPeers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
peers, err := c.Peers.List(context.Background())
require.NoError(t, err)
require.NotEmpty(t, peers)
filteredPeers, err := c.Peers.List(context.Background(), rest.PeerIPFilter("192.168.10.0"))
require.NoError(t, err)
require.Empty(t, filteredPeers)
peer, err := c.Peers.Get(context.Background(), peers[0].Id)
require.NoError(t, err)
assert.Equal(t, peers[0].Id, peer.Id)
peer, err = c.Peers.Update(context.Background(), peer.Id, api.PeerRequest{
LoginExpirationEnabled: true,
Name: "Test",
SshEnabled: false,
ApprovalRequired: ptr(false),
InactivityExpirationEnabled: false,
})
require.NoError(t, err)
assert.Equal(t, true, peer.LoginExpirationEnabled)
accessiblePeers, err := c.Peers.ListAccessiblePeers(context.Background(), peer.Id)
require.NoError(t, err)
assert.Empty(t, accessiblePeers)
err = c.Peers.Delete(context.Background(), peer.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,96 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// PoliciesAPI APIs for Policies, do not use directly
type PoliciesAPI struct {
c *Client
}
// List list all policies
// See more: https://docs.netbird.io/api/resources/policies#list-all-policies
func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
path := "/api/policies"
resp, err := a.c.NewRequest(ctx, "GET", path, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Policy](resp)
return ret, err
}
// Get get policy info
// See more: https://docs.netbird.io/api/resources/policies#retrieve-a-policy
func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
// Create create new policy
// See more: https://docs.netbird.io/api/resources/policies#create-a-policy
func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSONRequestBody) (*api.Policy, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
// Update update policy info
// See more: https://docs.netbird.io/api/resources/policies#update-a-policy
func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.PutApiPoliciesPolicyIdJSONRequestBody) (*api.Policy, error) {
path := "/api/policies/" + policyID
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Policy](resp)
return &ret, err
}
// Delete delete policy
// See more: https://docs.netbird.io/api/resources/policies#delete-a-policy
func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,241 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testPolicy = api.Policy{
Name: "wow",
Id: ptr("Test"),
Enabled: false,
}
)
func TestPolicies_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Policy{testPolicy})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPolicy, ret[0])
})
}
func TestPolicies_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPolicies_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPolicy)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testPolicy, *ret)
})
}
func TestPolicies_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPolicies_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiPoliciesPolicyIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testPolicy)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Create(context.Background(), api.PostApiPoliciesJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testPolicy, *ret)
})
}
func TestPolicies_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Create(context.Background(), api.PostApiPoliciesJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPolicies_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiPoliciesPolicyIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testPolicy)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Update(context.Background(), "Test", api.PutApiPoliciesPolicyIdJSONRequestBody{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testPolicy, *ret)
})
}
func TestPolicies_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Policies.Update(context.Background(), "Test", api.PutApiPoliciesPolicyIdJSONRequestBody{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPolicies_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Policies.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestPolicies_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Policies.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestPolicies_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
policies, err := c.Policies.List(context.Background())
require.NoError(t, err)
require.NotEmpty(t, policies)
policy, err := c.Policies.Get(context.Background(), *policies[0].Id)
require.NoError(t, err)
assert.Equal(t, *policies[0].Id, *policy.Id)
policy, err = c.Policies.Update(context.Background(), *policy.Id, api.PolicyCreate{
Description: ptr("Test Policy"),
Enabled: false,
Name: "Test",
Rules: []api.PolicyRuleUpdate{
{
Action: api.PolicyRuleUpdateAction(policy.Rules[0].Action),
Bidirectional: true,
Description: ptr("Test Policy"),
Sources: ptr([]string{(*policy.Rules[0].Sources)[0].Id}),
Destinations: ptr([]string{(*policy.Rules[0].Destinations)[0].Id}),
Enabled: false,
Protocol: api.PolicyRuleUpdateProtocolAll,
},
},
SourcePostureChecks: nil,
})
require.NoError(t, err)
assert.Equal(t, "Test Policy", *policy.Rules[0].Description)
policy, err = c.Policies.Create(context.Background(), api.PolicyUpdate{
Description: ptr("Test Policy 2"),
Enabled: false,
Name: "Test",
Rules: []api.PolicyRuleUpdate{
{
Action: api.PolicyRuleUpdateAction(policy.Rules[0].Action),
Bidirectional: true,
Description: ptr("Test Policy 2"),
Sources: ptr([]string{(*policy.Rules[0].Sources)[0].Id}),
Destinations: ptr([]string{(*policy.Rules[0].Destinations)[0].Id}),
Enabled: false,
Protocol: api.PolicyRuleUpdateProtocolAll,
},
},
SourcePostureChecks: nil,
})
require.NoError(t, err)
err = c.Policies.Delete(context.Background(), *policy.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// PostureChecksAPI APIs for PostureChecks, do not use directly
type PostureChecksAPI struct {
c *Client
}
// List list all posture checks
// See more: https://docs.netbird.io/api/resources/posture-checks#list-all-posture-checks
func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.PostureCheck](resp)
return ret, err
}
// Get get posture check info
// See more: https://docs.netbird.io/api/resources/posture-checks#retrieve-a-posture-check
func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api.PostureCheck, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
// Create create new posture check
// See more: https://docs.netbird.io/api/resources/posture-checks#create-a-posture-check
func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostureChecksJSONRequestBody) (*api.PostureCheck, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
// Update update posture check info
// See more: https://docs.netbird.io/api/resources/posture-checks#update-a-posture-check
func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, request api.PutApiPostureChecksPostureCheckIdJSONRequestBody) (*api.PostureCheck, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PostureCheck](resp)
return &ret, err
}
// Delete delete posture check
// See more: https://docs.netbird.io/api/resources/posture-checks#delete-a-posture-check
func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,233 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testPostureCheck = api.PostureCheck{
Id: "Test",
Name: "wow",
}
)
func TestPostureChecks_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.PostureCheck{testPostureCheck})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testPostureCheck, ret[0])
})
}
func TestPostureChecks_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPostureChecks_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testPostureCheck)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testPostureCheck, *ret)
})
}
func TestPostureChecks_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestPostureChecks_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostureCheckUpdate
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testPostureCheck)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testPostureCheck, *ret)
})
}
func TestPostureChecks_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPostureChecks_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostureCheckUpdate
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "weaw", req.Name)
retBytes, _ := json.Marshal(testPostureCheck)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Update(context.Background(), "Test", api.PostureCheckUpdate{
Name: "weaw",
})
require.NoError(t, err)
assert.Equal(t, testPostureCheck, *ret)
})
}
func TestPostureChecks_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.PostureChecks.Update(context.Background(), "Test", api.PostureCheckUpdate{
Name: "weaw",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestPostureChecks_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.PostureChecks.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestPostureChecks_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.PostureChecks.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestPostureChecks_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
check, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
Name: "Test",
Description: "Testing",
Checks: &api.Checks{
OsVersionCheck: &api.OSVersionCheck{
Windows: &api.MinKernelVersionCheck{
MinKernelVersion: "0.0.0",
},
},
},
})
require.NoError(t, err)
assert.Equal(t, "Test", check.Name)
checks, err := c.PostureChecks.List(context.Background())
require.NoError(t, err)
assert.Len(t, checks, 1)
check, err = c.PostureChecks.Update(context.Background(), check.Id, api.PostureCheckUpdate{
Name: "Tests",
Description: "Testings",
Checks: &api.Checks{
GeoLocationCheck: &api.GeoLocationCheck{
Action: api.GeoLocationCheckActionAllow, Locations: []api.Location{
{
CityName: ptr("Cairo"),
CountryCode: "EG",
},
},
},
},
})
require.NoError(t, err)
assert.Equal(t, "Testings", *check.Description)
check, err = c.PostureChecks.Get(context.Background(), check.Id)
require.NoError(t, err)
assert.Equal(t, "Tests", check.Name)
err = c.PostureChecks.Delete(context.Background(), check.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,92 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// RoutesAPI APIs for Routes, do not use directly
type RoutesAPI struct {
c *Client
}
// List list all routes
// See more: https://docs.netbird.io/api/resources/routes#list-all-routes
func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Route](resp)
return ret, err
}
// Get get route info
// See more: https://docs.netbird.io/api/resources/routes#retrieve-a-route
func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes/"+routeID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
// Create create new route
// See more: https://docs.netbird.io/api/resources/routes#create-a-route
func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONRequestBody) (*api.Route, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
// Update update route info
// See more: https://docs.netbird.io/api/resources/routes#update-a-route
func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutApiRoutesRouteIdJSONRequestBody) (*api.Route, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Route](resp)
return &ret, err
}
// Delete delete route
// See more: https://docs.netbird.io/api/resources/routes#delete-a-route
func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/routes/"+routeID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,231 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testRoute = api.Route{
Id: "Test",
Domains: ptr([]string{"google.com"}),
}
)
func TestRoutes_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.Route{testRoute})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testRoute, ret[0])
})
}
func TestRoutes_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestRoutes_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testRoute)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testRoute, *ret)
})
}
func TestRoutes_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestRoutes_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiRoutesJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "meow", req.Description)
retBytes, _ := json.Marshal(testRoute)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Create(context.Background(), api.PostApiRoutesJSONRequestBody{
Description: "meow",
})
require.NoError(t, err)
assert.Equal(t, testRoute, *ret)
})
}
func TestRoutes_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Create(context.Background(), api.PostApiRoutesJSONRequestBody{
Description: "meow",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestRoutes_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiRoutesRouteIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "meow", req.Description)
retBytes, _ := json.Marshal(testRoute)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Update(context.Background(), "Test", api.PutApiRoutesRouteIdJSONRequestBody{
Description: "meow",
})
require.NoError(t, err)
assert.Equal(t, testRoute, *ret)
})
}
func TestRoutes_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Routes.Update(context.Background(), "Test", api.PutApiRoutesRouteIdJSONRequestBody{
Description: "meow",
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestRoutes_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Routes.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestRoutes_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Routes.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestRoutes_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
route, err := c.Routes.Create(context.Background(), api.RouteRequest{
Description: "Meow",
Enabled: false,
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
PeerGroups: ptr([]string{"cs1tnh0hhcjnqoiuebeg"}),
Domains: ptr([]string{"google.com"}),
Masquerade: true,
Metric: 9999,
KeepRoute: false,
NetworkId: "Test",
})
require.NoError(t, err)
assert.Equal(t, "Test", route.NetworkId)
routes, err := c.Routes.List(context.Background())
require.NoError(t, err)
assert.Len(t, routes, 1)
route, err = c.Routes.Update(context.Background(), route.Id, api.RouteRequest{
Description: "Testings",
Enabled: false,
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
PeerGroups: ptr([]string{"cs1tnh0hhcjnqoiuebeg"}),
Domains: ptr([]string{"google.com"}),
Masquerade: true,
Metric: 9999,
KeepRoute: false,
NetworkId: "Tests",
})
require.NoError(t, err)
assert.Equal(t, "Testings", route.Description)
route, err = c.Routes.Get(context.Background(), route.Id)
require.NoError(t, err)
assert.Equal(t, "Tests", route.NetworkId)
err = c.Routes.Delete(context.Background(), route.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,94 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// SetupKeysAPI APIs for Setup keys, do not use directly
type SetupKeysAPI struct {
c *Client
}
// List list all setup keys
// See more: https://docs.netbird.io/api/resources/setup-keys#list-all-setup-keys
func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.SetupKey](resp)
return ret, err
}
// Get get setup key info
// See more: https://docs.netbird.io/api/resources/setup-keys#retrieve-a-setup-key
func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKey, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SetupKey](resp)
return &ret, err
}
// Create generate new Setup Key
// See more: https://docs.netbird.io/api/resources/setup-keys#create-a-setup-key
func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJSONRequestBody) (*api.SetupKeyClear, error) {
path := "/api/setup-keys"
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", path, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SetupKeyClear](resp)
return &ret, err
}
// Update generate new Setup Key
// See more: https://docs.netbird.io/api/resources/setup-keys#update-a-setup-key
func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request api.PutApiSetupKeysKeyIdJSONRequestBody) (*api.SetupKey, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.SetupKey](resp)
return &ret, err
}
// Delete delete setup key
// See more: https://docs.netbird.io/api/resources/setup-keys#delete-a-setup-key
func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,232 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testSetupKey = api.SetupKey{
Id: "Test",
Name: "wow",
AutoGroups: []string{"meow"},
Ephemeral: true,
}
testSteupKeyGenerated = api.SetupKeyClear{
Id: "Test",
Name: "wow",
AutoGroups: []string{"meow"},
Ephemeral: true,
Key: "shhh",
}
)
func TestSetupKeys_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.SetupKey{testSetupKey})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testSetupKey, ret[0])
})
}
func TestSetupKeys_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestSetupKeys_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testSetupKey)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Get(context.Background(), "Test")
require.NoError(t, err)
assert.Equal(t, testSetupKey, *ret)
})
}
func TestSetupKeys_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Get(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestSetupKeys_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiSetupKeysJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, 5, req.ExpiresIn)
retBytes, _ := json.Marshal(testSteupKeyGenerated)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Create(context.Background(), api.PostApiSetupKeysJSONRequestBody{
ExpiresIn: 5,
})
require.NoError(t, err)
assert.Equal(t, testSteupKeyGenerated, *ret)
})
}
func TestSetupKeys_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Create(context.Background(), api.PostApiSetupKeysJSONRequestBody{
ExpiresIn: 5,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestSetupKeys_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiSetupKeysKeyIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, req.Revoked)
retBytes, _ := json.Marshal(testSetupKey)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Update(context.Background(), "Test", api.PutApiSetupKeysKeyIdJSONRequestBody{
Revoked: true,
})
require.NoError(t, err)
assert.Equal(t, testSetupKey, *ret)
})
}
func TestSetupKeys_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.SetupKeys.Update(context.Background(), "Test", api.PutApiSetupKeysKeyIdJSONRequestBody{
Revoked: true,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestSetupKeys_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.SetupKeys.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestSetupKeys_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.SetupKeys.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestSetupKeys_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
group, err := c.Groups.Create(context.Background(), api.GroupRequest{
Name: "Test",
})
require.NoError(t, err)
skClear, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
AutoGroups: []string{group.Id},
Ephemeral: ptr(false),
Name: "test",
Type: "reusable",
})
require.NoError(t, err)
assert.Equal(t, true, skClear.Valid)
keys, err := c.SetupKeys.List(context.Background())
require.NoError(t, err)
assert.Len(t, keys, 2)
sk, err := c.SetupKeys.Update(context.Background(), skClear.Id, api.SetupKeyRequest{
Revoked: true,
AutoGroups: []string{group.Id},
})
require.NoError(t, err)
sk, err = c.SetupKeys.Get(context.Background(), sk.Id)
require.NoError(t, err)
assert.Equal(t, false, sk.Valid)
err = c.SetupKeys.Delete(context.Background(), sk.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,74 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// TokensAPI APIs for PATs, do not use directly
type TokensAPI struct {
c *Client
}
// List list user tokens
// See more: https://docs.netbird.io/api/resources/tokens#list-all-tokens
func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAccessToken, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.PersonalAccessToken](resp)
return ret, err
}
// Get get user token info
// See more: https://docs.netbird.io/api/resources/tokens#retrieve-a-token
func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.PersonalAccessToken, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PersonalAccessToken](resp)
return &ret, err
}
// Create generate new PAT for user
// See more: https://docs.netbird.io/api/resources/tokens#create-a-token
func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostApiUsersUserIdTokensJSONRequestBody) (*api.PersonalAccessTokenGenerated, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.PersonalAccessTokenGenerated](resp)
return &ret, err
}
// Delete delete user token
// See more: https://docs.netbird.io/api/resources/tokens#delete-a-token
func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,180 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testToken = api.PersonalAccessToken{
Id: "Test",
CreatedAt: time.Time{},
CreatedBy: "meow",
ExpirationDate: time.Time{},
LastUsed: nil,
Name: "wow",
}
testTokenGenerated = api.PersonalAccessTokenGenerated{
PersonalAccessToken: testToken,
PlainToken: "shhh",
}
)
func TestTokens_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.PersonalAccessToken{testToken})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.List(context.Background(), "meow")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testToken, ret[0])
})
}
func TestTokens_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.List(context.Background(), "meow")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestTokens_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testToken)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.Get(context.Background(), "meow", "Test")
require.NoError(t, err)
assert.Equal(t, testToken, *ret)
})
}
func TestTokens_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.Get(context.Background(), "meow", "Test")
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestTokens_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiUsersUserIdTokensJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, 5, req.ExpiresIn)
retBytes, _ := json.Marshal(testTokenGenerated)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.Create(context.Background(), "meow", api.PostApiUsersUserIdTokensJSONRequestBody{
ExpiresIn: 5,
})
require.NoError(t, err)
assert.Equal(t, testTokenGenerated, *ret)
})
}
func TestTokens_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Tokens.Create(context.Background(), "meow", api.PostApiUsersUserIdTokensJSONRequestBody{
ExpiresIn: 5,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestTokens_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Tokens.Delete(context.Background(), "meow", "Test")
require.NoError(t, err)
})
}
func TestTokens_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Tokens.Delete(context.Background(), "meow", "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestTokens_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
tokenClear, err := c.Tokens.Create(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", api.PersonalAccessTokenRequest{
Name: "Test",
ExpiresIn: 365,
})
require.NoError(t, err)
assert.Equal(t, "Test", tokenClear.PersonalAccessToken.Name)
tokens, err := c.Tokens.List(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003")
require.NoError(t, err)
assert.Len(t, tokens, 2)
token, err := c.Tokens.Get(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", tokenClear.PersonalAccessToken.Id)
require.NoError(t, err)
assert.Equal(t, "Test", token.Name)
err = c.Tokens.Delete(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", token.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,107 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/management/server/http/api"
)
// UsersAPI APIs for users, do not use directly
type UsersAPI struct {
c *Client
}
// List list all users, only returns one user always
// See more: https://docs.netbird.io/api/resources/users#list-all-users
func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.User](resp)
return ret, err
}
// Create create user
// See more: https://docs.netbird.io/api/resources/users#create-a-user
func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONRequestBody) (*api.User, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.User](resp)
return &ret, err
}
// Update update user settings
// See more: https://docs.netbird.io/api/resources/users#update-a-user
func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApiUsersUserIdJSONRequestBody) (*api.User, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.User](resp)
return &ret, err
}
// Delete delete user
// See more: https://docs.netbird.io/api/resources/users#delete-a-user
func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// ResendInvitation resend user invitation
// See more: https://docs.netbird.io/api/resources/users#resend-user-invitation
func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// Current gets the current user info
// See more: https://docs.netbird.io/api/resources/users#retrieve-current-user
func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/current", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.User](resp)
return &ret, err
}

View File

@@ -0,0 +1,258 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
)
var (
testUser = api.User{
Id: "Test",
AutoGroups: []string{"test-group"},
Email: "test@test.com",
IsBlocked: false,
IsCurrent: ptr(false),
IsServiceUser: ptr(false),
Issued: ptr("api"),
LastLogin: &time.Time{},
Name: "M. Essam",
Role: "user",
Status: api.UserStatusActive,
}
)
func TestUsers_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal([]api.User{testUser})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.List(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testUser, ret[0])
})
}
func TestUsers_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.List(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestUsers_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiUsersJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, []string{"meow"}, req.AutoGroups)
retBytes, _ := json.Marshal(testUser)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Create(context.Background(), api.PostApiUsersJSONRequestBody{
AutoGroups: []string{"meow"},
})
require.NoError(t, err)
assert.Equal(t, testUser, *ret)
})
}
func TestUsers_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Create(context.Background(), api.PostApiUsersJSONRequestBody{
AutoGroups: []string{"meow"},
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiUsersUserIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, true, req.IsBlocked)
retBytes, _ := json.Marshal(testUser)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Update(context.Background(), "Test", api.PutApiUsersUserIdJSONRequestBody{
IsBlocked: true,
})
require.NoError(t, err)
assert.Equal(t, testUser, *ret)
})
}
func TestUsers_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Update(context.Background(), "Test", api.PutApiUsersUserIdJSONRequestBody{
IsBlocked: true,
})
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Nil(t, ret)
})
}
func TestUsers_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.Users.Delete(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestUsers_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.Delete(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestUsers_ResendInvitation_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
w.WriteHeader(200)
})
err := c.Users.ResendInvitation(context.Background(), "Test")
require.NoError(t, err)
})
}
func TestUsers_ResendInvitation_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.Users.ResendInvitation(context.Background(), "Test")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestUsers_Current_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(testUser)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Current(context.Background())
require.NoError(t, err)
assert.Equal(t, testUser, *ret)
})
}
func TestUsers_Current_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.Users.Current(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestUsers_Integration(t *testing.T) {
withBlackBoxServer(t, func(c *rest.Client) {
// rest client PAT is owner's
current, err := c.Users.Current(context.Background())
require.NoError(t, err)
assert.Equal(t, "a23efe53-63fb-11ec-90d6-0242ac120003", current.Id)
assert.Equal(t, "owner", current.Role)
user, err := c.Users.Create(context.Background(), api.UserCreateRequest{
AutoGroups: []string{},
Email: ptr("test@example.com"),
IsServiceUser: true,
Name: ptr("Nobody"),
Role: "user",
})
require.NoError(t, err)
assert.Equal(t, "Nobody", user.Name)
users, err := c.Users.List(context.Background())
require.NoError(t, err)
assert.NotEmpty(t, users)
user, err = c.Users.Update(context.Background(), user.Id, api.UserRequest{
AutoGroups: []string{},
Role: "admin",
})
require.NoError(t, err)
assert.Equal(t, "admin", user.Role)
err = c.Users.Delete(context.Background(), user.Id)
require.NoError(t, err)
})
}

View File

@@ -0,0 +1,45 @@
package domain
import (
"strings"
"golang.org/x/net/idna"
)
// Domain represents a punycode-encoded domain string.
// This should only be converted from a string when the string already is in punycode, otherwise use FromString.
type Domain string
// String converts the Domain to a non-punycode string.
// For an infallible conversion, use SafeString.
func (d Domain) String() (string, error) {
unicode, err := idna.ToUnicode(string(d))
if err != nil {
return "", err
}
return unicode, nil
}
// SafeString converts the Domain to a non-punycode string, falling back to the punycode string if conversion fails.
func (d Domain) SafeString() string {
str, err := d.String()
if err != nil {
return string(d)
}
return str
}
// PunycodeString returns the punycode representation of the Domain.
// This should only be used if a punycode domain is expected but only a string is supported.
func (d Domain) PunycodeString() string {
return string(d)
}
// FromString creates a Domain from a string, converting it to punycode.
func FromString(s string) (Domain, error) {
ascii, err := idna.ToASCII(s)
if err != nil {
return "", err
}
return Domain(strings.ToLower(ascii)), nil
}

View File

@@ -0,0 +1 @@
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=

View File

@@ -0,0 +1,108 @@
package domain
import (
"sort"
"strings"
)
// List is a slice of punycode-encoded domain strings.
type List []Domain
// ToStringList converts a List to a slice of string.
func (d List) ToStringList() ([]string, error) {
var list []string
for _, domain := range d {
s, err := domain.String()
if err != nil {
return nil, err
}
list = append(list, s)
}
return list, nil
}
// ToPunycodeList converts the List to a slice of Punycode-encoded domain strings.
func (d List) ToPunycodeList() []string {
var list []string
for _, domain := range d {
list = append(list, string(domain))
}
return list
}
// ToSafeStringList converts the List to a slice of non-punycode strings.
// If a domain cannot be converted, the original string is used.
func (d List) ToSafeStringList() []string {
var list []string
for _, domain := range d {
list = append(list, domain.SafeString())
}
return list
}
// String converts List to a comma-separated string.
func (d List) String() (string, error) {
list, err := d.ToStringList()
if err != nil {
return "", err
}
return strings.Join(list, ", "), nil
}
// SafeString converts List to a comma-separated non-punycode string.
// If a domain cannot be converted, the original string is used.
func (d List) SafeString() string {
str, err := d.String()
if err != nil {
return d.PunycodeString()
}
return str
}
// PunycodeString converts the List to a comma-separated string of Punycode-encoded domains.
func (d List) PunycodeString() string {
return strings.Join(d.ToPunycodeList(), ", ")
}
func (d List) Equal(domains List) bool {
if len(d) != len(domains) {
return false
}
sort.Slice(d, func(i, j int) bool {
return d[i] < d[j]
})
sort.Slice(domains, func(i, j int) bool {
return domains[i] < domains[j]
})
for i, domain := range d {
if domain != domains[i] {
return false
}
}
return true
}
// FromStringList creates a DomainList from a slice of string.
func FromStringList(s []string) (List, error) {
var dl List
for _, domain := range s {
d, err := FromString(domain)
if err != nil {
return nil, err
}
dl = append(dl, d)
}
return dl, nil
}
// FromPunycodeList creates a List from a slice of Punycode-encoded domain strings.
func FromPunycodeList(s []string) List {
var dl List
for _, domain := range s {
dl = append(dl, Domain(strings.ToLower(domain)))
}
return dl
}

View File

@@ -0,0 +1,49 @@
package domain
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_EqualReturnsTrueForIdenticalLists(t *testing.T) {
list1 := List{"domain1", "domain2", "domain3"}
list2 := List{"domain1", "domain2", "domain3"}
assert.True(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForDifferentLengths(t *testing.T) {
list1 := List{"domain1", "domain2"}
list2 := List{"domain1", "domain2", "domain3"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForDifferentElements(t *testing.T) {
list1 := List{"domain1", "domain2", "domain3"}
list2 := List{"domain1", "domain4", "domain3"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsTrueForUnsortedIdenticalLists(t *testing.T) {
list1 := List{"domain3", "domain1", "domain2"}
list2 := List{"domain1", "domain2", "domain3"}
assert.True(t, list1.Equal(list2))
}
func Test_EqualReturnsFalseForEmptyAndNonEmptyList(t *testing.T) {
list1 := List{}
list2 := List{"domain1"}
assert.False(t, list1.Equal(list2))
}
func Test_EqualReturnsTrueForBothEmptyLists(t *testing.T) {
list1 := List{}
list2 := List{}
assert.True(t, list1.Equal(list2))
}

View File

@@ -0,0 +1,56 @@
package domain
import (
"fmt"
"regexp"
"strings"
)
const maxDomains = 32
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
func ValidateDomains(domains []string) (List, error) {
if len(domains) == 0 {
return nil, fmt.Errorf("domains list is empty")
}
if len(domains) > maxDomains {
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}
var domainList List
for _, d := range domains {
// handles length and idna conversion
punycode, err := FromString(d)
if err != nil {
return domainList, fmt.Errorf("convert domain to punycode: %s: %w", d, err)
}
if !domainRegex.MatchString(string(punycode)) {
return domainList, fmt.Errorf("invalid domain format: %s", d)
}
domainList = append(domainList, punycode)
}
return domainList, nil
}
// ValidateDomainsList checks if each domain in the list is valid
func ValidateDomainsList(domains []string) error {
if len(domains) == 0 {
return nil
}
if len(domains) > maxDomains {
return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
}
for _, d := range domains {
d := strings.ToLower(d)
if !domainRegex.MatchString(d) {
return fmt.Errorf("invalid domain format: %s", d)
}
}
return nil
}

View File

@@ -0,0 +1,185 @@
package domain
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestValidateDomains(t *testing.T) {
tests := []struct {
name string
domains []string
expected List
wantErr bool
}{
{
name: "Empty list",
domains: nil,
expected: nil,
wantErr: true,
},
{
name: "Valid ASCII domain",
domains: []string{"sub.ex-ample.com"},
expected: List{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Valid Unicode domain",
domains: []string{"münchen.de"},
expected: List{"xn--mnchen-3ya.de"},
wantErr: false,
},
{
name: "Valid Unicode, all labels",
domains: []string{"中国.中国.中国"},
expected: List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"},
wantErr: false,
},
{
name: "With underscores",
domains: []string{"_jabber._tcp.gmail.com"},
expected: List{"_jabber._tcp.gmail.com"},
wantErr: false,
},
{
name: "Invalid domain format",
domains: []string{"-example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid domain format 2",
domains: []string{"example.com-"},
expected: nil,
wantErr: true,
},
{
name: "Multiple domains valid and invalid",
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
expected: List{"google.com"},
wantErr: true,
},
{
name: "Valid wildcard domain",
domains: []string{"*.example.com"},
expected: List{"*.example.com"},
wantErr: false,
},
{
name: "Wildcard with dot domain",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
},
{
name: "Wildcard with dot domain",
domains: []string{".*.example.com"},
expected: nil,
wantErr: true,
},
{
name: "Invalid wildcard domain",
domains: []string{"a.*.example.com"},
expected: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ValidateDomains(tt.domains)
assert.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, got, tt.expected)
})
}
}
func TestValidateDomainsList(t *testing.T) {
validDomains := make([]string, maxDomains)
for i := range maxDomains {
validDomains[i] = fmt.Sprintf("example%d.com", i)
}
tests := []struct {
name string
domains []string
wantErr bool
}{
{
name: "Empty list",
domains: nil,
wantErr: false,
},
{
name: "Single valid ASCII domain",
domains: []string{"sub.ex-ample.com"},
wantErr: false,
},
{
name: "Underscores in labels",
domains: []string{"_jabber._tcp.gmail.com"},
wantErr: false,
},
{
// Unlike ValidateDomains (which converts to punycode),
// ValidateDomainsStrSlice will fail on non-ASCII domain chars.
name: "Unicode domain fails (no punycode conversion)",
domains: []string{"münchen.de"},
wantErr: true,
},
{
name: "Invalid domain format - leading dash",
domains: []string{"-example.com"},
wantErr: true,
},
{
name: "Invalid domain format - trailing dash",
domains: []string{"example-.com"},
wantErr: true,
},
{
name: "Multiple domains with a valid one, then invalid",
domains: []string{"google.com", "invalid_domain.com-"},
wantErr: true,
},
{
name: "Valid wildcard domain",
domains: []string{"*.example.com"},
wantErr: false,
},
{
name: "Wildcard with leading dot - invalid",
domains: []string{".*.example.com"},
wantErr: true,
},
{
name: "Invalid wildcard with multiple asterisks",
domains: []string{"a.*.example.com"},
wantErr: true,
},
{
name: "Exactly maxDomains items (valid)",
domains: validDomains,
wantErr: false,
},
{
name: "Exceeds maxDomains items",
domains: append(validDomains, "extra.com"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateDomainsList(tt.domains)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./management.proto --go_out=../ --go-grpc_out=../
cd "$old_pwd"

View File

@@ -0,0 +1,2 @@
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,542 @@
syntax = "proto3";
import "google/protobuf/timestamp.proto";
import "google/protobuf/duration.proto";
option go_package = "/proto";
package management;
service ManagementService {
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
// Returns encrypted LoginResponse in EncryptedMessage.Body
rpc Login(EncryptedMessage) returns (EncryptedMessage) {}
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
// Returns encrypted SyncResponse in EncryptedMessage.Body
rpc Sync(EncryptedMessage) returns (stream EncryptedMessage) {}
// Exposes a Wireguard public key of the Management service.
// This key is used to support message encryption between client and server
rpc GetServerKey(Empty) returns (ServerKeyResponse) {}
// health check endpoint
rpc isHealthy(Empty) returns (Empty) {}
// Exposes a device authorization flow information
// This is used for initiating a Oauth 2 device authorization grant flow
// which will be used by our clients to Login.
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
rpc GetDeviceAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
// Exposes a PKCE authorization code flow information
// This is used for initiating a Oauth 2 authorization grant flow
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
rpc GetPKCEAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
rpc SyncMeta(EncryptedMessage) returns (Empty) {}
// Logout logs out the peer and removes it from the management server
rpc Logout(EncryptedMessage) returns (Empty) {}
}
message EncryptedMessage {
// Wireguard public key
string wgPubKey = 1;
// encrypted message Body
bytes body = 2;
// Version of the Netbird Management Service protocol
int32 version = 3;
}
message SyncRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
message SyncResponse {
// Global config
NetbirdConfig netbirdConfig = 1;
// Deprecated. Use NetworkMap.PeerConfig
PeerConfig peerConfig = 2;
// Deprecated. Use NetworkMap.RemotePeerConfig
repeated RemotePeerConfig remotePeers = 3;
// Indicates whether remotePeers array is empty or not to bypass protobuf null and empty array equality.
// Deprecated. Use NetworkMap.remotePeersIsEmpty
bool remotePeersIsEmpty = 4;
NetworkMap NetworkMap = 5;
// Posture checks to be evaluated by client
repeated Checks Checks = 6;
}
message SyncMetaRequest {
// Meta data of the peer
PeerSystemMeta meta = 1;
}
message LoginRequest {
// Pre-authorized setup key (can be empty)
string setupKey = 1;
// Meta data of the peer (e.g. name, os_name, os_version,
PeerSystemMeta meta = 2;
// SSO token (can be empty)
string jwtToken = 3;
// Can be absent for now.
PeerKeys peerKeys = 4;
repeated string dnsLabels = 5;
}
// PeerKeys is additional peer info like SSH pub key and WireGuard public key.
// This message is sent on Login or register requests, or when a key rotation has to happen.
message PeerKeys {
// sshPubKey represents a public SSH key of the peer. Can be absent.
bytes sshPubKey = 1;
// wgPubKey represents a public WireGuard key of the peer. Can be absent.
bytes wgPubKey = 2;
}
// Environment is part of the PeerSystemMeta and describes the environment the agent is running in.
message Environment {
// cloud is the cloud provider the agent is running in if applicable.
string cloud = 1;
// platform is the platform the agent is running on if applicable.
string platform = 2;
}
// File represents a file on the system.
message File {
// path is the path to the file.
string path = 1;
// exist indicate whether the file exists.
bool exist = 2;
// processIsRunning indicates whether the file is a running process or not.
bool processIsRunning = 3;
}
message Flags {
bool rosenpassEnabled = 1;
bool rosenpassPermissive = 2;
bool serverSSHAllowed = 3;
bool disableClientRoutes = 4;
bool disableServerRoutes = 5;
bool disableDNS = 6;
bool disableFirewall = 7;
bool blockLANAccess = 8;
bool blockInbound = 9;
bool lazyConnectionEnabled = 10;
}
// PeerSystemMeta is machine meta data like OS and version.
message PeerSystemMeta {
string hostname = 1;
string goOS = 2;
string kernel = 3;
string core = 4;
string platform = 5;
string OS = 6;
string netbirdVersion = 7;
string uiVersion = 8;
string kernelVersion = 9;
string OSVersion = 10;
repeated NetworkAddress networkAddresses = 11;
string sysSerialNumber = 12;
string sysProductName = 13;
string sysManufacturer = 14;
Environment environment = 15;
repeated File files = 16;
Flags flags = 17;
}
message LoginResponse {
// Global config
NetbirdConfig netbirdConfig = 1;
// Peer local config
PeerConfig peerConfig = 2;
// Posture checks to be evaluated by client
repeated Checks Checks = 3;
}
message ServerKeyResponse {
// Server's Wireguard public key
string key = 1;
// Key expiration timestamp after which the key should be fetched again by the client
google.protobuf.Timestamp expiresAt = 2;
// Version of the Netbird Management Service protocol
int32 version = 3;
}
message Empty {}
// NetbirdConfig is a common configuration of any Netbird peer. It contains STUN, TURN, Signal and Management servers configurations
message NetbirdConfig {
// a list of STUN servers
repeated HostConfig stuns = 1;
// a list of TURN servers
repeated ProtectedHostConfig turns = 2;
// a Signal server config
HostConfig signal = 3;
RelayConfig relay = 4;
FlowConfig flow = 5;
}
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
message HostConfig {
// URI of the resource e.g. turns://stun.netbird.io:4430 or signal.netbird.io:10000
string uri = 1;
Protocol protocol = 2;
enum Protocol {
UDP = 0;
TCP = 1;
HTTP = 2;
HTTPS = 3;
DTLS = 4;
}
}
message RelayConfig {
repeated string urls = 1;
string tokenPayload = 2;
string tokenSignature = 3;
}
message FlowConfig {
string url = 1;
string tokenPayload = 2;
string tokenSignature = 3;
google.protobuf.Duration interval = 4;
bool enabled = 5;
// counters determines if flow packets and bytes counters should be sent
bool counters = 6;
// exitNodeCollection determines if event collection on exit nodes should be enabled
bool exitNodeCollection = 7;
// dnsCollection determines if DNS event collection should be enabled
bool dnsCollection = 8;
}
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
message ProtectedHostConfig {
HostConfig hostConfig = 1;
string user = 2;
string password = 3;
}
// PeerConfig represents a configuration of a "our" peer.
// The properties are used to configure local Wireguard
message PeerConfig {
// Peer's virtual IP address within the Netbird VPN (a Wireguard address config)
string address = 1;
// Netbird DNS server (a Wireguard DNS config)
string dns = 2;
// SSHConfig of the peer.
SSHConfig sshConfig = 3;
// Peer fully qualified domain name
string fqdn = 4;
bool RoutingPeerDnsResolutionEnabled = 5;
bool LazyConnectionEnabled = 6;
}
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
message NetworkMap {
// Serial is an ID of the network state to be used by clients to order updates.
// The larger the Serial the newer the configuration.
// E.g. the client app should keep track of this id locally and discard all the configurations with a lower value
uint64 Serial = 1;
// PeerConfig represents configuration of a peer
PeerConfig peerConfig = 2;
// RemotePeerConfig represents a list of remote peers that the receiver can connect to
repeated RemotePeerConfig remotePeers = 3;
// Indicates whether remotePeers array is empty or not to bypass protobuf null and empty array equality.
bool remotePeersIsEmpty = 4;
// List of routes to be applied
repeated Route Routes = 5;
// DNS config to be applied
DNSConfig DNSConfig = 6;
// RemotePeerConfig represents a list of remote peers that the receiver can connect to
repeated RemotePeerConfig offlinePeers = 7;
// FirewallRule represents a list of firewall rules to be applied to peer
repeated FirewallRule FirewallRules = 8;
// firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality.
bool firewallRulesIsEmpty = 9;
// RoutesFirewallRules represents a list of routes firewall rules to be applied to peer
repeated RouteFirewallRule routesFirewallRules = 10;
// RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality.
bool routesFirewallRulesIsEmpty = 11;
repeated ForwardingRule forwardingRules = 12;
}
// RemotePeerConfig represents a configuration of a remote peer.
// The properties are used to configure WireGuard Peers sections
message RemotePeerConfig {
// A WireGuard public key of a remote peer
string wgPubKey = 1;
// WireGuard allowed IPs of a remote peer e.g. [10.30.30.1/32]
repeated string allowedIps = 2;
// SSHConfig is a SSH config of the remote peer. SSHConfig.sshPubKey should be ignored because peer knows it's SSH key.
SSHConfig sshConfig = 3;
// Peer fully qualified domain name
string fqdn = 4;
string agentVersion = 5;
}
// SSHConfig represents SSH configurations of a peer.
message SSHConfig {
// sshEnabled indicates whether a SSH server is enabled on this peer
bool sshEnabled = 1;
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
// This property should be ignore if SSHConfig comes from PeerConfig.
bytes sshPubKey = 2;
}
// DeviceAuthorizationFlowRequest empty struct for future expansion
message DeviceAuthorizationFlowRequest {}
// DeviceAuthorizationFlow represents Device Authorization Flow information
// that can be used by the client to login initiate a Oauth 2.0 device authorization grant flow
// see https://datatracker.ietf.org/doc/html/rfc8628
message DeviceAuthorizationFlow {
// An IDP provider , (eg. Auth0)
provider Provider = 1;
ProviderConfig ProviderConfig = 2;
enum provider {
HOSTED = 0;
}
}
// PKCEAuthorizationFlowRequest empty struct for future expansion
message PKCEAuthorizationFlowRequest {}
// PKCEAuthorizationFlow represents Authorization Code Flow information
// that can be used by the client to login initiate a Oauth 2.0 authorization code grant flow
// with Proof Key for Code Exchange (PKCE). See https://datatracker.ietf.org/doc/html/rfc7636
message PKCEAuthorizationFlow {
ProviderConfig ProviderConfig = 1;
}
// ProviderConfig has all attributes needed to initiate a device/pkce authorization flow
message ProviderConfig {
// An IDP application client id
string ClientID = 1;
// An IDP application client secret
string ClientSecret = 2;
// An IDP API domain
// Deprecated. Use a DeviceAuthEndpoint and TokenEndpoint
string Domain = 3;
// An Audience for validation
string Audience = 4;
// DeviceAuthEndpoint is an endpoint to request device authentication code.
string DeviceAuthEndpoint = 5;
// TokenEndpoint is an endpoint to request auth token.
string TokenEndpoint = 6;
// Scopes provides the scopes to be included in the token request
string Scope = 7;
// UseIDToken indicates if the id token should be used for authentication
bool UseIDToken = 8;
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code.
string AuthorizationEndpoint = 9;
// RedirectURLs handles authorization code from IDP manager
repeated string RedirectURLs = 10;
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
bool DisablePromptLogin = 11;
// LoginFlags sets the PKCE flow login details
uint32 LoginFlag = 12;
}
// Route represents a route.Route object
message Route {
string ID = 1;
string Network = 2;
int64 NetworkType = 3;
string Peer = 4;
int64 Metric = 5;
bool Masquerade = 6;
string NetID = 7;
repeated string Domains = 8;
bool keepRoute = 9;
}
// DNSConfig represents a dns.Update
message DNSConfig {
bool ServiceEnable = 1;
repeated NameServerGroup NameServerGroups = 2;
repeated CustomZone CustomZones = 3;
}
// CustomZone represents a dns.CustomZone
message CustomZone {
string Domain = 1;
repeated SimpleRecord Records = 2;
}
// SimpleRecord represents a dns.SimpleRecord
message SimpleRecord {
string Name = 1;
int64 Type = 2;
string Class = 3;
int64 TTL = 4;
string RData = 5;
}
// NameServerGroup represents a dns.NameServerGroup
message NameServerGroup {
repeated NameServer NameServers = 1;
bool Primary = 2;
repeated string Domains = 3;
bool SearchDomainsEnabled = 4;
}
// NameServer represents a dns.NameServer
message NameServer {
string IP = 1;
int64 NSType = 2;
int64 Port = 3;
}
enum RuleProtocol {
UNKNOWN = 0;
ALL = 1;
TCP = 2;
UDP = 3;
ICMP = 4;
CUSTOM = 5;
}
enum RuleDirection {
IN = 0;
OUT = 1;
}
enum RuleAction {
ACCEPT = 0;
DROP = 1;
}
// FirewallRule represents a firewall rule
message FirewallRule {
string PeerIP = 1;
RuleDirection Direction = 2;
RuleAction Action = 3;
RuleProtocol Protocol = 4;
string Port = 5;
PortInfo PortInfo = 6;
// PolicyID is the ID of the policy that this rule belongs to
bytes PolicyID = 7;
}
message NetworkAddress {
string netIP = 1;
string mac = 2;
}
message Checks {
repeated string Files = 1;
}
message PortInfo {
oneof portSelection {
uint32 port = 1;
Range range = 2;
}
message Range {
uint32 start = 1;
uint32 end = 2;
}
}
// RouteFirewallRule signifies a firewall rule applicable for a routed network.
message RouteFirewallRule {
// sourceRanges IP ranges of the routing peers.
repeated string sourceRanges = 1;
// Action to be taken by the firewall when the rule is applicable.
RuleAction action = 2;
// Network prefix for the routed network.
string destination = 3;
// Protocol of the routed network.
RuleProtocol protocol = 4;
// Details about the port.
PortInfo portInfo = 5;
// IsDynamic indicates if the route is a DNS route.
bool isDynamic = 6;
// Domains is a list of domains for which the rule is applicable.
repeated string domains = 7;
// CustomProtocol is a custom protocol ID.
uint32 customProtocol = 8;
// PolicyID is the ID of the policy that this rule belongs to
bytes PolicyID = 9;
// RouteID is the ID of the route that this rule belongs to
string RouteID = 10;
}
message ForwardingRule {
// Protocol of the forwarding rule
RuleProtocol protocol = 1;
// portInfo is the ingress destination port information, where the traffic arrives in the gateway node
PortInfo destinationPort = 2;
// IP address of the translated address (remote peer) to send traffic to
bytes translatedAddress = 3;
// Translated port information, where the traffic should be forwarded to
PortInfo translatedPort = 4;
}

View File

@@ -0,0 +1,429 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
package proto
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// ManagementServiceClient is the client API for ManagementService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ManagementServiceClient interface {
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
// Returns encrypted LoginResponse in EncryptedMessage.Body
Login(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
// Returns encrypted SyncResponse in EncryptedMessage.Body
Sync(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (ManagementService_SyncClient, error)
// Exposes a Wireguard public key of the Management service.
// This key is used to support message encryption between client and server
GetServerKey(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*ServerKeyResponse, error)
// health check endpoint
IsHealthy(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
// Exposes a device authorization flow information
// This is used for initiating a Oauth 2 device authorization grant flow
// which will be used by our clients to Login.
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
GetDeviceAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// Exposes a PKCE authorization code flow information
// This is used for initiating a Oauth 2 authorization grant flow
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
// Logout logs out the peer and removes it from the management server
Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
}
type managementServiceClient struct {
cc grpc.ClientConnInterface
}
func NewManagementServiceClient(cc grpc.ClientConnInterface) ManagementServiceClient {
return &managementServiceClient{cc}
}
func (c *managementServiceClient) Login(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/management.ManagementService/Login", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) Sync(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (ManagementService_SyncClient, error) {
stream, err := c.cc.NewStream(ctx, &ManagementService_ServiceDesc.Streams[0], "/management.ManagementService/Sync", opts...)
if err != nil {
return nil, err
}
x := &managementServiceSyncClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type ManagementService_SyncClient interface {
Recv() (*EncryptedMessage, error)
grpc.ClientStream
}
type managementServiceSyncClient struct {
grpc.ClientStream
}
func (x *managementServiceSyncClient) Recv() (*EncryptedMessage, error) {
m := new(EncryptedMessage)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *managementServiceClient) GetServerKey(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*ServerKeyResponse, error) {
out := new(ServerKeyResponse)
err := c.cc.Invoke(ctx, "/management.ManagementService/GetServerKey", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) IsHealthy(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/isHealthy", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) GetDeviceAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/management.ManagementService/GetDeviceAuthorizationFlow", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/management.ManagementService/GetPKCEAuthorizationFlow", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/SyncMeta", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *managementServiceClient) Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
out := new(Empty)
err := c.cc.Invoke(ctx, "/management.ManagementService/Logout", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// ManagementServiceServer is the server API for ManagementService service.
// All implementations must embed UnimplementedManagementServiceServer
// for forward compatibility
type ManagementServiceServer interface {
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
// Returns encrypted LoginResponse in EncryptedMessage.Body
Login(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
// Returns encrypted SyncResponse in EncryptedMessage.Body
Sync(*EncryptedMessage, ManagementService_SyncServer) error
// Exposes a Wireguard public key of the Management service.
// This key is used to support message encryption between client and server
GetServerKey(context.Context, *Empty) (*ServerKeyResponse, error)
// health check endpoint
IsHealthy(context.Context, *Empty) (*Empty, error)
// Exposes a device authorization flow information
// This is used for initiating a Oauth 2 device authorization grant flow
// which will be used by our clients to Login.
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
GetDeviceAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// Exposes a PKCE authorization code flow information
// This is used for initiating a Oauth 2 authorization grant flow
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// SyncMeta is used to sync metadata of the peer.
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
// sync meta will evaluate the checks and update the peer meta with the result.
// EncryptedMessage of the request has a body of Empty.
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
// Logout logs out the peer and removes it from the management server
Logout(context.Context, *EncryptedMessage) (*Empty, error)
mustEmbedUnimplementedManagementServiceServer()
}
// UnimplementedManagementServiceServer must be embedded to have forward compatible implementations.
type UnimplementedManagementServiceServer struct {
}
func (UnimplementedManagementServiceServer) Login(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method Login not implemented")
}
func (UnimplementedManagementServiceServer) Sync(*EncryptedMessage, ManagementService_SyncServer) error {
return status.Errorf(codes.Unimplemented, "method Sync not implemented")
}
func (UnimplementedManagementServiceServer) GetServerKey(context.Context, *Empty) (*ServerKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetServerKey not implemented")
}
func (UnimplementedManagementServiceServer) IsHealthy(context.Context, *Empty) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method IsHealthy not implemented")
}
func (UnimplementedManagementServiceServer) GetDeviceAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetDeviceAuthorizationFlow not implemented")
}
func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
}
func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
}
func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMessage) (*Empty, error) {
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
}
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to ManagementServiceServer will
// result in compilation errors.
type UnsafeManagementServiceServer interface {
mustEmbedUnimplementedManagementServiceServer()
}
func RegisterManagementServiceServer(s grpc.ServiceRegistrar, srv ManagementServiceServer) {
s.RegisterService(&ManagementService_ServiceDesc, srv)
}
func _ManagementService_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).Login(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/Login",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).Login(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_Sync_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(EncryptedMessage)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(ManagementServiceServer).Sync(m, &managementServiceSyncServer{stream})
}
type ManagementService_SyncServer interface {
Send(*EncryptedMessage) error
grpc.ServerStream
}
type managementServiceSyncServer struct {
grpc.ServerStream
}
func (x *managementServiceSyncServer) Send(m *EncryptedMessage) error {
return x.ServerStream.SendMsg(m)
}
func _ManagementService_GetServerKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).GetServerKey(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/GetServerKey",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).GetServerKey(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_IsHealthy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(Empty)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).IsHealthy(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/isHealthy",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).IsHealthy(ctx, req.(*Empty))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_GetDeviceAuthorizationFlow_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).GetDeviceAuthorizationFlow(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/GetDeviceAuthorizationFlow",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).GetDeviceAuthorizationFlow(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_GetPKCEAuthorizationFlow_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).GetPKCEAuthorizationFlow(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/GetPKCEAuthorizationFlow",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).GetPKCEAuthorizationFlow(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).SyncMeta(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/SyncMeta",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).SyncMeta(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _ManagementService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ManagementServiceServer).Logout(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/management.ManagementService/Logout",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ManagementServiceServer).Logout(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var ManagementService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "management.ManagementService",
HandlerType: (*ManagementServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Login",
Handler: _ManagementService_Login_Handler,
},
{
MethodName: "GetServerKey",
Handler: _ManagementService_GetServerKey_Handler,
},
{
MethodName: "isHealthy",
Handler: _ManagementService_IsHealthy_Handler,
},
{
MethodName: "GetDeviceAuthorizationFlow",
Handler: _ManagementService_GetDeviceAuthorizationFlow_Handler,
},
{
MethodName: "GetPKCEAuthorizationFlow",
Handler: _ManagementService_GetPKCEAuthorizationFlow_Handler,
},
{
MethodName: "SyncMeta",
Handler: _ManagementService_SyncMeta_Handler,
},
{
MethodName: "Logout",
Handler: _ManagementService_Logout_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "Sync",
Handler: _ManagementService_Sync_Handler,
ServerStreams: true,
},
},
Metadata: "management.proto",
}

View File

@@ -0,0 +1,14 @@
package allow
// Auth is a Validator that allows all connections.
// Used this for testing purposes only.
type Auth struct {
}
func (a *Auth) Validate(any) error {
return nil
}
func (a *Auth) ValidateHelloMsgType(any) error {
return nil
}

26
shared/relay/auth/doc.go Normal file
View File

@@ -0,0 +1,26 @@
/*
Package auth manages the authentication process with the relay server.
Key Components:
Validator: The Validator interface defines the Validate method. Any type that provides this method can be used as a
Validator.
Methods:
Validate(func() hash.Hash, any): This method is defined in the Validator interface and is used to validate the authentication.
Usage:
To create a new AllowAllAuth validator, simply instantiate it:
validator := &allow.Auth{}
To validate the authentication, use the Validate method:
err := validator.Validate(sha256.New, any)
This package provides a simple and effective way to manage authentication with the relay server, ensuring that the
peers are authenticated properly.
*/
package auth

1
shared/relay/auth/go.sum Normal file
View File

@@ -0,0 +1 @@
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=

View File

@@ -0,0 +1,8 @@
/*
This package uses a similar HMAC method for authentication with the TURN server. The Management server provides the
tokens for the peers. The peers manage these tokens in the token store. The token store is a simple thread safe store
that keeps the tokens in memory. These tokens are used to authenticate the peers with the Relay server in the hello
message.
*/
package hmac

View File

@@ -0,0 +1,44 @@
package hmac
import (
"encoding/base64"
"fmt"
"sync"
v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
)
// TokenStore is a simple in-memory store for token
// With this can update the token in thread safe way
type TokenStore struct {
mu sync.Mutex
token []byte
}
func (a *TokenStore) UpdateToken(token *Token) error {
a.mu.Lock()
defer a.mu.Unlock()
if token == nil {
return nil
}
sig, err := base64.StdEncoding.DecodeString(token.Signature)
if err != nil {
return fmt.Errorf("decode signature: %w", err)
}
tok := v2.Token{
AuthAlgo: v2.AuthAlgoHMACSHA256,
Signature: sig,
Payload: []byte(token.Payload),
}
a.token = tok.Marshal()
return nil
}
func (a *TokenStore) TokenBinary() []byte {
a.mu.Lock()
defer a.mu.Unlock()
return a.token
}

View File

@@ -0,0 +1,94 @@
package hmac
import (
"bytes"
"crypto/hmac"
"encoding/base64"
"encoding/gob"
"fmt"
"hash"
"strconv"
"time"
log "github.com/sirupsen/logrus"
)
type Token struct {
Payload string
Signature string
}
func unmarshalToken(payload []byte) (Token, error) {
var creds Token
buffer := bytes.NewBuffer(payload)
decoder := gob.NewDecoder(buffer)
err := decoder.Decode(&creds)
return creds, err
}
// TimedHMAC generates a token with TTL and uses a pre-shared secret known to the relay server
type TimedHMAC struct {
secret string
timeToLive time.Duration
}
// NewTimedHMAC creates a new TimedHMAC instance
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
return &TimedHMAC{
secret: secret,
timeToLive: timeToLive,
}
}
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC
// hash of a timestamp with a preshared TURN secret
func (m *TimedHMAC) GenerateToken(algo func() hash.Hash) (*Token, error) {
timeAuth := time.Now().Add(m.timeToLive).Unix()
timeStamp := strconv.FormatInt(timeAuth, 10)
checksum, err := m.generate(algo, timeStamp)
if err != nil {
return nil, err
}
return &Token{
Payload: timeStamp,
Signature: base64.StdEncoding.EncodeToString(checksum),
}, nil
}
// Validate checks if the token is valid
func (m *TimedHMAC) Validate(algo func() hash.Hash, token Token) error {
expectedMAC, err := m.generate(algo, token.Payload)
if err != nil {
return err
}
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
return fmt.Errorf("signature mismatch")
}
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
if time.Now().Unix() > timeAuthInt {
return fmt.Errorf("expired token")
}
return nil
}
func (m *TimedHMAC) generate(algo func() hash.Hash, payload string) ([]byte, error) {
mac := hmac.New(algo, []byte(m.secret))
_, err := mac.Write([]byte(payload))
if err != nil {
log.Debugf("failed to generate token: %s", err)
return nil, fmt.Errorf("failed to generate token: %w", err)
}
return mac.Sum(nil), nil
}

View File

@@ -0,0 +1,105 @@
package hmac
import (
"crypto/sha1"
"crypto/sha256"
"encoding/base64"
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "secret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken(sha1.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if creds.Payload == "" {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(creds.Payload, 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
_, err = base64.StdEncoding.DecodeString(creds.Signature)
if err != nil {
t.Fatalf("expected signature to be base64 encoded, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
// Test valid token
creds, err := manager.GenerateToken(sha1.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err := manager.Validate(sha1.New, *creds); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
manager := NewTimedHMAC(secret, timeToLive)
creds, err := manager.GenerateToken(sha256.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
invalidCreds := &Token{
Payload: creds.Payload,
Signature: "invalidsignature",
}
if err = manager.Validate(sha1.New, *invalidCreds); err == nil {
t.Fatalf("expected invalid token due to signature mismatch")
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
v := NewTimedHMAC(secret, -1*time.Hour)
expiredCreds, err := v.GenerateToken(sha256.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if err = v.Validate(sha1.New, *expiredCreds); err == nil {
t.Fatalf("expected invalid token due to expiration")
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
v := NewTimedHMAC(secret, timeToLive)
creds, err := v.GenerateToken(sha256.New)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
// Test invalid payload
invalidPayloadCreds := &Token{
Payload: "invalidtimestamp",
Signature: creds.Signature,
}
if err = v.Validate(sha1.New, *invalidPayloadCreds); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@@ -0,0 +1,40 @@
package v2
import (
"crypto/sha256"
"hash"
)
const (
AuthAlgoUnknown AuthAlgo = iota
AuthAlgoHMACSHA256
)
type AuthAlgo uint8
func (a AuthAlgo) String() string {
switch a {
case AuthAlgoHMACSHA256:
return "HMAC-SHA256"
default:
return "Unknown"
}
}
func (a AuthAlgo) New() func() hash.Hash {
switch a {
case AuthAlgoHMACSHA256:
return sha256.New
default:
return nil
}
}
func (a AuthAlgo) Size() int {
switch a {
case AuthAlgoHMACSHA256:
return sha256.Size
default:
return 0
}
}

View File

@@ -0,0 +1,45 @@
package v2
import (
"crypto/hmac"
"fmt"
"hash"
"strconv"
"time"
)
type Generator struct {
algo func() hash.Hash
algoType AuthAlgo
secret []byte
timeToLive time.Duration
}
func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
algoFunc := algo.New()
if algoFunc == nil {
return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
}
return &Generator{
algo: algoFunc,
algoType: algo,
secret: secret,
timeToLive: timeToLive,
}, nil
}
func (g *Generator) GenerateToken() (*Token, error) {
expirationTime := time.Now().Add(g.timeToLive).Unix()
payload := []byte(strconv.FormatInt(expirationTime, 10))
h := hmac.New(g.algo, g.secret)
h.Write(payload)
signature := h.Sum(nil)
return &Token{
AuthAlgo: g.algoType,
Signature: signature,
Payload: payload,
}, nil
}

View File

@@ -0,0 +1,110 @@
package v2
import (
"strconv"
"testing"
"time"
)
func TestGenerateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if len(token.Payload) == 0 {
t.Fatalf("expected non-empty payload")
}
_, err = strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
}
}
func TestValidateCredentials(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err != nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidSignature(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestExpired(t *testing.T) {
secret := "supersecret"
timeToLive := -1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected valid token: %s", err)
}
}
func TestInvalidPayload(t *testing.T) {
secret := "supersecret"
timeToLive := 1 * time.Hour
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
if err != nil {
t.Fatalf("failed to create generator: %v", err)
}
token, err := g.GenerateToken()
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
v := NewValidator([]byte(secret))
if err := v.Validate(token.Marshal()); err == nil {
t.Fatalf("expected invalid token due to invalid payload")
}
}

View File

@@ -0,0 +1,39 @@
package v2
import "errors"
type Token struct {
AuthAlgo AuthAlgo
Signature []byte
Payload []byte
}
func (t *Token) Marshal() []byte {
size := 1 + len(t.Signature) + len(t.Payload)
buf := make([]byte, size)
buf[0] = byte(t.AuthAlgo)
copy(buf[1:], t.Signature)
copy(buf[1+len(t.Signature):], t.Payload)
return buf
}
func UnmarshalToken(data []byte) (*Token, error) {
if len(data) == 0 {
return nil, errors.New("invalid token data")
}
algo := AuthAlgo(data[0])
sigSize := algo.Size()
if len(data) < 1+sigSize {
return nil, errors.New("invalid token data: insufficient length")
}
return &Token{
AuthAlgo: algo,
Signature: data[1 : 1+sigSize],
Payload: data[1+sigSize:],
}, nil
}

View File

@@ -0,0 +1,59 @@
package v2
import (
"crypto/hmac"
"errors"
"fmt"
"strconv"
"time"
)
const minLengthUnixTimestamp = 10
type Validator struct {
secret []byte
}
func NewValidator(secret []byte) *Validator {
return &Validator{secret: secret}
}
func (v *Validator) Validate(data any) error {
d, ok := data.([]byte)
if !ok {
return fmt.Errorf("invalid data type")
}
token, err := UnmarshalToken(d)
if err != nil {
return fmt.Errorf("unmarshal token: %w", err)
}
if len(token.Payload) < minLengthUnixTimestamp {
return errors.New("invalid payload: insufficient length")
}
hashFunc := token.AuthAlgo.New()
if hashFunc == nil {
return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
}
h := hmac.New(hashFunc, v.secret)
h.Write(token.Payload)
expectedMAC := h.Sum(nil)
if !hmac.Equal(token.Signature, expectedMAC) {
return errors.New("invalid signature")
}
timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
if err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
if time.Now().Unix() > timestamp {
return fmt.Errorf("expired token")
}
return nil
}

View File

@@ -0,0 +1,33 @@
package hmac
import (
"crypto/sha256"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
type TimedHMACValidator struct {
*TimedHMAC
}
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
ta := NewTimedHMAC(secret, duration)
return &TimedHMACValidator{
ta,
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
b, ok := credentials.([]byte)
if !ok {
return fmt.Errorf("invalid credentials type")
}
c, err := unmarshalToken(b)
if err != nil {
log.Debugf("failed to unmarshal token: %s", err)
return err
}
return a.TimedHMAC.Validate(sha256.New, c)
}

View File

@@ -0,0 +1,28 @@
package auth
import (
"time"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
)
type TimedHMACValidator struct {
authenticatorV2 *authv2.Validator
authenticator *auth.TimedHMACValidator
}
func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
return &TimedHMACValidator{
authenticatorV2: authv2.NewValidator(secret),
authenticator: auth.NewTimedHMACValidator(string(secret), duration),
}
}
func (a *TimedHMACValidator) Validate(credentials any) error {
return a.authenticatorV2.Validate(credentials)
}
func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
return a.authenticator.Validate(credentials)
}

View File

@@ -0,0 +1,13 @@
package client
type RelayAddr struct {
addr string
}
func (a RelayAddr) Network() string {
return "relay"
}
func (a RelayAddr) String() string {
return a.addr
}

View File

@@ -0,0 +1,665 @@
package client
import (
"context"
"fmt"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/shared/relay/client/dialer"
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
"github.com/netbirdio/netbird/relay/healthcheck"
"github.com/netbirdio/netbird/relay/messages"
)
const (
bufferSize = 8820
serverResponseTimeout = 8 * time.Second
)
var (
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
)
type internalStopFlag struct {
sync.Mutex
stop bool
}
func newInternalStopFlag() *internalStopFlag {
return &internalStopFlag{}
}
func (isf *internalStopFlag) set() {
isf.Lock()
defer isf.Unlock()
isf.stop = true
}
func (isf *internalStopFlag) isSet() bool {
isf.Lock()
defer isf.Unlock()
return isf.stop
}
// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer.
type Msg struct {
Payload []byte
bufPool *sync.Pool
bufPtr *[]byte
}
func (m *Msg) Free() {
m.bufPool.Put(m.bufPtr)
}
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
// server and forwarding them to the upper layer content reader.
type connContainer struct {
log *log.Entry
conn *Conn
messages chan Msg
msgChanLock sync.Mutex
closed bool // flag to check if channel is closed
ctx context.Context
cancel context.CancelFunc
}
func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
ctx, cancel := context.WithCancel(context.Background())
return &connContainer{
log: log,
conn: conn,
messages: messages,
ctx: ctx,
cancel: cancel,
}
}
func (cc *connContainer) writeMsg(msg Msg) {
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
msg.Free()
return
}
select {
case cc.messages <- msg:
case <-cc.ctx.Done():
msg.Free()
default:
msg.Free()
}
}
func (cc *connContainer) close() {
cc.cancel()
cc.msgChanLock.Lock()
defer cc.msgChanLock.Unlock()
if cc.closed {
return
}
cc.closed = true
close(cc.messages)
for msg := range cc.messages {
msg.Free()
}
}
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
type Client struct {
log *log.Entry
connectionURL string
authTokenStore *auth.TokenStore
hashedID messages.PeerID
bufPool *sync.Pool
relayConn net.Conn
conns map[messages.PeerID]*connContainer
serviceIsRunning bool
mu sync.Mutex // protect serviceIsRunning and conns
readLoopMutex sync.Mutex
wgReadLoop sync.WaitGroup
instanceURL *RelayAddr
muInstanceURL sync.Mutex
onDisconnectListener func(string)
listenerMutex sync.Mutex
stateSubscription *PeersStateSubscription
}
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
hashedID := messages.HashID(peerID)
relayLog := log.WithFields(log.Fields{"relay": serverURL})
c := &Client{
log: relayLog,
connectionURL: serverURL,
authTokenStore: authTokenStore,
hashedID: hashedID,
bufPool: &sync.Pool{
New: func() any {
buf := make([]byte, bufferSize)
return &buf
},
},
conns: make(map[messages.PeerID]*connContainer),
}
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
return c
}
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
func (c *Client) Connect(ctx context.Context) error {
c.log.Infof("connecting to relay server")
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
defer c.mu.Unlock()
if c.serviceIsRunning {
return nil
}
instanceURL, err := c.connect(ctx)
if err != nil {
return err
}
c.muInstanceURL.Lock()
c.instanceURL = instanceURL
c.muInstanceURL.Unlock()
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
c.log = c.log.WithField("relay", instanceURL.String())
c.log.Infof("relay connection established")
c.serviceIsRunning = true
internallyStoppedFlag := newInternalStopFlag()
hc := healthcheck.NewReceiver(c.log)
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
c.wgReadLoop.Add(1)
go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
return nil
}
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
// it will return immediately.
// It block until the server confirm the peer is online.
// todo: what should happen if call with the same peerID with multiple times?
func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
peerID := messages.HashID(dstPeerID)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}
_, ok := c.conns[peerID]
if ok {
c.mu.Unlock()
return nil, ErrConnAlreadyExists
}
c.mu.Unlock()
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
c.log.Errorf("peer not available: %s, %s", peerID, err)
return nil, err
}
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
msgChannel := make(chan Msg, 100)
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
return nil, fmt.Errorf("relay connection is not established")
}
c.muInstanceURL.Lock()
instanceURL := c.instanceURL
c.muInstanceURL.Unlock()
conn := NewConn(c, peerID, msgChannel, instanceURL)
_, ok = c.conns[peerID]
if ok {
c.mu.Unlock()
_ = conn.Close()
return nil, ErrConnAlreadyExists
}
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
c.mu.Unlock()
return conn, nil
}
// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
func (c *Client) ServerInstanceURL() (string, error) {
c.muInstanceURL.Lock()
defer c.muInstanceURL.Unlock()
if c.instanceURL == nil {
return "", fmt.Errorf("relay connection is not established")
}
return c.instanceURL.String(), nil
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
func (c *Client) SetOnDisconnectListener(fn func(string)) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn
}
// HasConns returns true if there are connections.
func (c *Client) HasConns() bool {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.conns) > 0
}
func (c *Client) Ready() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.serviceIsRunning
}
// Close closes the connection to the relay server and all connections to other peers.
func (c *Client) Close() error {
return c.close(true)
}
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
conn, err := rd.Dial()
if err != nil {
return nil, err
}
c.relayConn = conn
instanceURL, err := c.handShake(ctx)
if err != nil {
cErr := conn.Close()
if cErr != nil {
c.log.Errorf("failed to close connection: %s", cErr)
}
return nil, err
}
return instanceURL, nil
}
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
if err != nil {
c.log.Errorf("failed to marshal auth message: %s", err)
return nil, err
}
_, err = c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send auth message: %s", err)
return nil, err
}
buf := make([]byte, messages.MaxHandshakeRespSize)
n, err := c.readWithTimeout(ctx, buf)
if err != nil {
c.log.Errorf("failed to read auth response: %s", err)
return nil, err
}
_, err = messages.ValidateVersion(buf[:n])
if err != nil {
return nil, fmt.Errorf("validate version: %w", err)
}
msgType, err := messages.DetermineServerMessageType(buf[:n])
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
return nil, err
}
if msgType != messages.MsgTypeAuthResponse {
c.log.Errorf("unexpected message type: %s", msgType)
return nil, fmt.Errorf("unexpected message type")
}
addr, err := messages.UnmarshalAuthResponse(buf[:n])
if err != nil {
return nil, err
}
return &RelayAddr{addr: addr}, nil
}
func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
var (
errExit error
n int
)
for {
bufPtr := c.bufPool.Get().(*[]byte)
buf := *bufPtr
n, errExit = relayConn.Read(buf)
if errExit != nil {
c.log.Infof("start to Relay read loop exit")
c.mu.Lock()
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to read message from relay server: %s", errExit)
}
c.mu.Unlock()
c.bufPool.Put(bufPtr)
break
}
buf = buf[:n]
_, err := messages.ValidateVersion(buf)
if err != nil {
c.log.Errorf("failed to validate protocol version: %s", err)
c.bufPool.Put(bufPtr)
continue
}
msgType, err := messages.DetermineServerMessageType(buf)
if err != nil {
c.log.Errorf("failed to determine message type: %s", err)
c.bufPool.Put(bufPtr)
continue
}
if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) {
break
}
}
hc.Stop()
c.stateSubscription.Cleanup()
c.wgReadLoop.Done()
_ = c.close(false)
c.notifyDisconnected()
}
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
switch msgType {
case messages.MsgTypeHealthCheck:
c.handleHealthCheck(hc, internallyStoppedFlag)
c.bufPool.Put(bufPtr)
case messages.MsgTypeTransport:
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
case messages.MsgTypePeersOnline:
c.handlePeersOnlineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypePeersWentOffline:
c.handlePeersWentOfflineMsg(buf)
c.bufPool.Put(bufPtr)
return true
case messages.MsgTypeClose:
c.log.Debugf("relay connection close by server")
c.bufPool.Put(bufPtr)
return false
}
return true
}
func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
msg := messages.MarshalHealthcheck()
_, wErr := c.relayConn.Write(msg)
if wErr != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to send heartbeat: %s", wErr)
}
}
hc.Heartbeat()
}
func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
peerID, payload, err := messages.UnmarshalTransportMsg(buf)
if err != nil {
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
c.log.Errorf("failed to parse transport message: %v", err)
}
c.bufPool.Put(bufPtr)
return true
}
c.mu.Lock()
if !c.serviceIsRunning {
c.mu.Unlock()
c.bufPool.Put(bufPtr)
return false
}
container, ok := c.conns[*peerID]
c.mu.Unlock()
if !ok {
c.log.Errorf("peer not found: %s", peerID.String())
c.bufPool.Put(bufPtr)
return true
}
msg := Msg{
bufPool: c.bufPool,
bufPtr: bufPtr,
Payload: payload,
}
container.writeMsg(msg)
return true
}
func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
c.mu.Lock()
conn, ok := c.conns[dstID]
c.mu.Unlock()
if !ok {
return 0, net.ErrClosed
}
if conn.conn != connReference {
return 0, net.ErrClosed
}
// todo: use buffer pool instead of create new transport msg.
msg, err := messages.MarshalTransportMsg(dstID, payload)
if err != nil {
c.log.Errorf("failed to marshal transport message: %s", err)
return 0, err
}
// the write always return with 0 length because the underling does not support the size feedback.
_, err = c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to write transport message: %s", err)
}
return len(payload), err
}
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
for {
select {
case _, ok := <-hc.OnTimeout:
if !ok {
return
}
c.log.Errorf("health check timeout")
internalStopFlag.set()
if err := conn.Close(); err != nil {
// ignore the err handling because the readLoop will handle it
c.log.Warnf("failed to close connection: %s", err)
}
return
case <-ctx.Done():
err := c.close(true)
if err != nil {
c.log.Errorf("failed to teardown connection: %s", err)
}
return
}
}
}
func (c *Client) closeAllConns() {
for _, container := range c.conns {
container.close()
}
c.conns = make(map[messages.PeerID]*connContainer)
}
func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
c.mu.Lock()
defer c.mu.Unlock()
for _, peerID := range peerIDs {
container, ok := c.conns[peerID]
if !ok {
c.log.Warnf("can not close connection, peer not found: %s", peerID)
continue
}
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
container.close()
delete(c.conns, peerID)
}
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
}
}
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
c.mu.Lock()
defer c.mu.Unlock()
container, ok := c.conns[id]
if !ok {
return net.ErrClosed
}
if container.conn != connReference {
return fmt.Errorf("conn reference mismatch")
}
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
}
c.log.Infof("free up connection to peer: %s", id)
delete(c.conns, id)
container.close()
return nil
}
func (c *Client) close(gracefullyExit bool) error {
c.readLoopMutex.Lock()
defer c.readLoopMutex.Unlock()
c.mu.Lock()
var err error
if !c.serviceIsRunning {
c.mu.Unlock()
c.log.Warn("relay connection was already marked as not running")
return nil
}
c.serviceIsRunning = false
c.muInstanceURL.Lock()
c.instanceURL = nil
c.muInstanceURL.Unlock()
c.log.Infof("closing all peer connections")
c.closeAllConns()
if gracefullyExit {
c.writeCloseMsg()
}
err = c.relayConn.Close()
c.mu.Unlock()
c.log.Infof("waiting for read loop to close")
c.wgReadLoop.Wait()
c.log.Infof("relay connection closed")
return err
}
func (c *Client) notifyDisconnected() {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
if c.onDisconnectListener == nil {
return
}
go c.onDisconnectListener(c.connectionURL)
}
func (c *Client) writeCloseMsg() {
msg := messages.MarshalCloseMsg()
_, err := c.relayConn.Write(msg)
if err != nil {
c.log.Errorf("failed to send close message: %s", err)
}
}
func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
defer cancel()
readDone := make(chan struct{})
var (
n int
err error
)
go func() {
n, err = c.relayConn.Read(buf)
close(readDone)
}()
select {
case <-ctx.Done():
return 0, fmt.Errorf("read operation timed out")
case <-readDone:
return n, err
}
}
func (c *Client) handlePeersOnlineMsg(buf []byte) {
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
return
}
c.stateSubscription.OnPeersOnline(peersID)
}
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
peersID, err := messages.UnMarshalPeersWentOffline(buf)
if err != nil {
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
return
}
c.stateSubscription.OnPeersWentOffline(peersID)
}

View File

@@ -0,0 +1,743 @@
package client
import (
"context"
"net"
"os"
"testing"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/shared/relay/auth/allow"
"github.com/netbirdio/netbird/shared/relay/auth/hmac"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/relay/server"
)
var (
hmacTokenStore = &hmac.TokenStore{}
serverListenAddr = "127.0.0.1:1234"
serverURL = "rel://127.0.0.1:1234"
serverCfg = server.Config{
Meter: otel.Meter(""),
ExposedAddress: serverURL,
TLSSupport: false,
AuthValidator: &allow.Auth{},
}
)
func TestMain(m *testing.M) {
_ = util.InitLog("debug", util.LogConsole)
code := m.Run()
os.Exit(code)
}
func TestClient(t *testing.T) {
ctx := context.Background()
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
listenCfg := server.ListenerConfig{Address: serverListenAddr}
err := srv.Listen(listenCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
t.Log("alice connecting to server")
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientAlice.Close()
t.Log("placeholder connecting to server")
clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
err = clientPlaceHolder.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientPlaceHolder.Close()
t.Log("Bob connecting to server")
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer clientBob.Close()
t.Log("Alice open connection to Bob")
connAliceToBob, err := clientAlice.OpenConn(ctx, "bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
t.Log("Bob open connection to Alice")
connBobToAlice, err := clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
log.Debugf("alice sent message to bob")
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
log.Debugf("on new message from alice to bob")
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestRegistration(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for server to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
_ = srv.Shutdown(ctx)
t.Fatalf("failed to connect to server: %s", err)
}
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
err = srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}
func TestRegistrationTimeout(t *testing.T) {
ctx := context.Background()
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind UDP server: %s", err)
}
defer func(fakeUDPListener *net.UDPConn) {
_ = fakeUDPListener.Close()
}(fakeUDPListener)
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
Port: 1234,
IP: net.ParseIP("0.0.0.0"),
})
if err != nil {
t.Fatalf("failed to bind TCP server: %s", err)
}
defer func(fakeTCPListener *net.TCPListener) {
_ = fakeTCPListener.Close()
}(fakeTCPListener)
clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err == nil {
t.Errorf("failed to connect to server: %s", err)
}
log.Debugf("%s", err)
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close conn: %s", err)
}
}
func TestEcho(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestBindToUnavailabePeer(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn(ctx, "bob")
if err == nil {
t.Errorf("expected error when binding to unavailable peer, got nil")
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestBindReconnect(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
err = clientBob.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
_, err = clientAlice.OpenConn(ctx, "bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
chBob, err := clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing client Alice")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
chAlice, err := clientAlice.OpenConn(ctx, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
testString := "hello alice, I am bob"
_, err = chBob.Write([]byte(testString))
if err == nil {
t.Errorf("expected error when writing to channel, got nil")
}
chBob, err = clientBob.OpenConn(ctx, "alice")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_, err = chBob.Write([]byte(testString))
if err != nil {
t.Errorf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := chAlice.Read(buf)
if err != nil {
t.Errorf("failed to read from channel: %s", err)
}
if testString != string(buf[:n]) {
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
}
log.Infof("closing client")
err = clientAlice.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
}
func TestCloseConn(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
log.Infof("closing server")
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
bob := NewClient(serverURL, hmacTokenStore, "bob")
err = bob.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Errorf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
log.Infof("closing connection")
err = conn.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = conn.Write([]byte("hello"))
if err == nil {
t.Errorf("unexpected writing from closed connection")
}
}
func TestCloseRelayConn(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
bob := NewClient(serverURL, hmacTokenStore, "bob")
err = bob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
_ = clientAlice.relayConn.Close()
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
_, err = clientAlice.OpenConn(ctx, "bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByServer(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv1, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
if err = relayClient.Connect(ctx); err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
defer func() {
if err := relayClient.Close(); err != nil {
log.Errorf("failed to close client: %s", err)
}
}()
disconnected := make(chan struct{})
relayClient.SetOnDisconnectListener(func(_ string) {
log.Infof("client disconnected")
close(disconnected)
})
err = srv1.Shutdown(ctx)
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
select {
case <-disconnected:
case <-time.After(3 * time.Second):
log.Errorf("timeout waiting for client to disconnect")
}
_, err = relayClient.OpenConn(ctx, "bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
}
func TestCloseByClient(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
log.Debugf("connect by alice")
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
err = relayClient.Connect(ctx)
if err != nil {
log.Fatalf("failed to connect to server: %s", err)
}
err = relayClient.Close()
if err != nil {
t.Errorf("failed to close client: %s", err)
}
_, err = relayClient.OpenConn(ctx, "bob")
if err == nil {
t.Errorf("unexpected opening connection to closed server")
}
err = srv.Shutdown(ctx)
if err != nil {
t.Fatalf("failed to close server: %s", err)
}
}
func TestCloseNotDrainedChannel(t *testing.T) {
ctx := context.Background()
idAlice := "alice"
idBob := "bob"
srvCfg := server.ListenerConfig{Address: serverListenAddr}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv.Listen(srvCfg)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
// wait for servers to start
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
err = clientAlice.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientAlice.Close()
if err != nil {
t.Errorf("failed to close Alice client: %s", err)
}
}()
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
err = clientBob.Connect(ctx)
if err != nil {
t.Fatalf("failed to connect to server: %s", err)
}
defer func() {
err := clientBob.Close()
if err != nil {
t.Errorf("failed to close Bob client: %s", err)
}
}()
connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
// the internal channel buffer size is 2. So we should overflow it
for i := 0; i < 5; i++ {
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
}
// wait for delivery
time.Sleep(1 * time.Second)
err = connBobToAlice.Close()
if err != nil {
t.Errorf("failed to close channel: %s", err)
}
}
func waitForServerToStart(errChan chan error) error {
select {
case err := <-errChan:
if err != nil {
return err
}
case <-time.After(300 * time.Millisecond):
return nil
}
return nil
}

View File

@@ -0,0 +1,74 @@
package client
import (
"net"
"time"
"github.com/netbirdio/netbird/relay/messages"
)
// Conn represent a connection to a relayed remote peer.
type Conn struct {
client *Client
dstID messages.PeerID
messageChan chan Msg
instanceURL *RelayAddr
}
// NewConn creates a new connection to a relayed remote peer.
// client: the client instance, it used to send messages to the destination peer
// dstID: the destination peer ID
// messageChan: the channel where the messages will be received
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
c := &Conn{
client: client,
dstID: dstID,
messageChan: messageChan,
instanceURL: instanceURL,
}
return c
}
func (c *Conn) Write(p []byte) (n int, err error) {
return c.client.writeTo(c, c.dstID, p)
}
func (c *Conn) Read(b []byte) (n int, err error) {
msg, ok := <-c.messageChan
if !ok {
return 0, net.ErrClosed
}
n = copy(b, msg.Payload)
msg.Free()
return n, nil
}
func (c *Conn) Close() error {
return c.client.closeConn(c, c.dstID)
}
func (c *Conn) LocalAddr() net.Addr {
return c.client.relayConn.LocalAddr()
}
func (c *Conn) RemoteAddr() net.Addr {
return c.instanceURL
}
func (c *Conn) SetDeadline(t time.Time) error {
//TODO implement me
panic("SetDeadline is not implemented")
}
func (c *Conn) SetReadDeadline(t time.Time) error {
//TODO implement me
panic("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
//TODO implement me
panic("SetReadDeadline is not implemented")
}

View File

@@ -0,0 +1,7 @@
package net
import "errors"
var (
ErrClosedByServer = errors.New("closed by server")
)

View File

@@ -0,0 +1,97 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
)
const (
Network = "quic"
)
type Addr struct {
addr string
}
func (a Addr) Network() string {
return Network
}
func (a Addr) String() string {
return a.addr
}
type Conn struct {
session quic.Connection
ctx context.Context
}
func NewConn(session quic.Connection) net.Conn {
return &Conn{
session: session,
ctx: context.Background(),
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
dgram, err := c.session.ReceiveDatagram(c.ctx)
if err != nil {
return 0, c.remoteCloseErrHandling(err)
}
n = copy(b, dgram)
return n, nil
}
func (c *Conn) Write(b []byte) (int, error) {
err := c.session.SendDatagram(b)
if err != nil {
err = c.remoteCloseErrHandling(err)
log.Errorf("failed to write to QUIC stream: %v", err)
return 0, err
}
return len(b), nil
}
func (c *Conn) RemoteAddr() net.Addr {
return c.session.RemoteAddr()
}
func (c *Conn) LocalAddr() net.Addr {
if c.session != nil {
return c.session.LocalAddr()
}
return Addr{addr: "unknown"}
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return nil
}
func (c *Conn) Close() error {
return c.session.CloseWithError(0, "normal closure")
}
func (c *Conn) remoteCloseErrHandling(err error) error {
var appErr *quic.ApplicationError
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
return netErr.ErrClosedByServer
}
return err
}

View File

@@ -0,0 +1,99 @@
package quic
import (
"context"
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/quic-go/quic-go"
log "github.com/sirupsen/logrus"
quictls "github.com/netbirdio/netbird/relay/tls"
nbnet "github.com/netbirdio/netbird/util/net"
)
type Dialer struct {
}
func (d Dialer) Protocol() string {
return Network
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
quicURL, err := prepareURL(address)
if err != nil {
return nil, err
}
// Get the base TLS config
tlsClientConfig := quictls.ClientQUICTLSConfig()
// Set ServerName to hostname if not an IP address
host, _, splitErr := net.SplitHostPort(quicURL)
if splitErr == nil && net.ParseIP(host) == nil {
// It's a hostname, not an IP - modify directly
tlsClientConfig.ServerName = host
}
quicConfig := &quic.Config{
KeepAlivePeriod: 30 * time.Second,
MaxIdleTimeout: 4 * time.Minute,
EnableDatagrams: true,
InitialPacketSize: 1452,
}
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
log.Errorf("failed to listen on UDP: %s", err)
return nil, err
}
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
if err != nil {
log.Errorf("failed to resolve UDP address: %s", err)
return nil, err
}
session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
return nil, err
}
conn := NewConn(session)
return conn, nil
}
func prepareURL(address string) (string, error) {
var host string
var defaultPort string
switch {
case strings.HasPrefix(address, "rels://"):
host = address[7:]
defaultPort = "443"
case strings.HasPrefix(address, "rel://"):
host = address[6:]
defaultPort = "80"
default:
return "", fmt.Errorf("unsupported scheme: %s", address)
}
finalHost, finalPort, err := net.SplitHostPort(host)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
return host + ":" + defaultPort, nil
}
// return any other split error as is
return "", err
}
return finalHost + ":" + finalPort, nil
}

View File

@@ -0,0 +1,98 @@
package dialer
import (
"context"
"errors"
"net"
"time"
log "github.com/sirupsen/logrus"
)
const (
DefaultConnectionTimeout = 30 * time.Second
)
type DialeFn interface {
Dial(ctx context.Context, address string) (net.Conn, error)
Protocol() string
}
type dialResult struct {
Conn net.Conn
Protocol string
Err error
}
type RaceDial struct {
log *log.Entry
serverURL string
dialerFns []DialeFn
connectionTimeout time.Duration
}
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
return &RaceDial{
log: log,
serverURL: serverURL,
dialerFns: dialerFns,
connectionTimeout: connectionTimeout,
}
}
func (r *RaceDial) Dial() (net.Conn, error) {
connChan := make(chan dialResult, len(r.dialerFns))
winnerConn := make(chan net.Conn, 1)
abortCtx, abort := context.WithCancel(context.Background())
defer abort()
for _, dfn := range r.dialerFns {
go r.dial(dfn, abortCtx, connChan)
}
go r.processResults(connChan, winnerConn, abort)
conn, ok := <-winnerConn
if !ok {
return nil, errors.New("failed to dial to Relay server on any protocol")
}
return conn, nil
}
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
defer cancel()
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
conn, err := dfn.Dial(ctx, r.serverURL)
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
}
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
var hasWinner bool
for i := 0; i < len(r.dialerFns); i++ {
dr := <-connChan
if dr.Err != nil {
if errors.Is(dr.Err, context.Canceled) {
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
} else {
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
}
continue
}
if hasWinner {
if cerr := dr.Conn.Close(); cerr != nil {
r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr)
}
continue
}
r.log.Infof("successfully dialed via: %s", dr.Protocol)
abort()
hasWinner = true
winnerConn <- dr.Conn
}
close(winnerConn)
}

View File

@@ -0,0 +1,252 @@
package dialer
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/sirupsen/logrus"
)
type MockAddr struct {
network string
}
func (m *MockAddr) Network() string {
return m.network
}
func (m *MockAddr) String() string {
return "1.2.3.4"
}
// MockDialer is a mock implementation of DialeFn
type MockDialer struct {
dialFunc func(ctx context.Context, address string) (net.Conn, error)
protocolStr string
}
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
return m.dialFunc(ctx, address)
}
func (m *MockDialer) Protocol() string {
return m.protocolStr
}
// MockConn implements net.Conn for testing
type MockConn struct {
remoteAddr net.Addr
}
func (m *MockConn) Read(b []byte) (n int, err error) {
return 0, nil
}
func (m *MockConn) Write(b []byte) (n int, err error) {
return 0, nil
}
func (m *MockConn) Close() error {
return nil
}
func (m *MockConn) LocalAddr() net.Addr {
return nil
}
func (m *MockConn) RemoteAddr() net.Addr {
return m.remoteAddr
}
func (m *MockConn) SetDeadline(t time.Time) error {
return nil
}
func (m *MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestRaceDialEmptyDialers(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error with empty dialers, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
}
}
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto := "test-protocol"
mockConn := &MockConn{
remoteAddr: &MockAddr{network: proto},
}
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn, nil
},
protocolStr: proto,
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
}
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto2 := "protocol2"
mockConn2 := &MockConn{
remoteAddr: &MockAddr{network: proto2},
}
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "proto1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return mockConn2, nil
},
protocolStr: "proto2",
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn.RemoteAddr().Network() != proto2 {
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
}
_ = conn.Close()
}
func TestRaceDialTimeout(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
mockDialer := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
return nil, ctx.Err()
},
protocolStr: "proto1",
}
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
func TestRaceDialAllDialersFail(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("first dialer failed")
},
protocolStr: "protocol1",
}
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
return nil, errors.New("second dialer failed")
},
protocolStr: "protocol2",
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err == nil {
t.Errorf("Expected an error, got nil")
}
if conn != nil {
t.Errorf("Expected nil connection, got %v", conn)
}
}
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
logger := logrus.NewEntry(logrus.New())
serverURL := "test.server.com"
proto1 := "protocol1"
proto2 := "protocol2"
mockConn1 := &MockConn{
remoteAddr: &MockAddr{network: proto1},
}
mockConn2 := &MockConn{
remoteAddr: &MockAddr{network: proto2},
}
mockDialer1 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
time.Sleep(1 * time.Second)
return mockConn1, nil
},
protocolStr: proto1,
}
mock2err := make(chan error)
mockDialer2 := &MockDialer{
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
<-ctx.Done()
mock2err <- ctx.Err()
return mockConn2, ctx.Err()
},
protocolStr: proto2,
}
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
conn, err := rd.Dial()
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if conn == nil {
t.Errorf("Expected non-nil connection")
}
if conn != mockConn1 {
t.Errorf("Expected first connection, got %v", conn)
}
select {
case <-time.After(3 * time.Second):
t.Errorf("Timed out waiting for second dialer to finish")
case err := <-mock2err:
if !errors.Is(err, context.Canceled) {
t.Errorf("Expected context.Canceled error, got %v", err)
}
}
}

View File

@@ -0,0 +1,17 @@
package ws
const (
Network = "ws"
)
type WebsocketAddr struct {
addr string
}
func (a WebsocketAddr) Network() string {
return Network
}
func (a WebsocketAddr) String() string {
return a.addr
}

View File

@@ -0,0 +1,67 @@
package ws
import (
"context"
"fmt"
"net"
"time"
"github.com/coder/websocket"
)
type Conn struct {
ctx context.Context
*websocket.Conn
remoteAddr WebsocketAddr
}
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
return &Conn{
ctx: context.Background(),
Conn: wsConn,
remoteAddr: WebsocketAddr{serverAddress},
}
}
func (c *Conn) Read(b []byte) (n int, err error) {
t, ioReader, err := c.Conn.Reader(c.ctx)
if err != nil {
// todo use ErrClosedByServer
return 0, err
}
if t != websocket.MessageBinary {
return 0, fmt.Errorf("unexpected message type")
}
return ioReader.Read(b)
}
func (c *Conn) Write(b []byte) (n int, err error) {
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
return 0, err
}
func (c *Conn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *Conn) LocalAddr() net.Addr {
return WebsocketAddr{addr: "unknown"}
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return fmt.Errorf("SetReadDeadline is not implemented")
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline is not implemented")
}
func (c *Conn) SetDeadline(t time.Time) error {
return fmt.Errorf("SetDeadline is not implemented")
}
func (c *Conn) Close() error {
return c.Conn.CloseNow()
}

View File

@@ -0,0 +1,90 @@
package ws
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"github.com/coder/websocket"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/server/listener/ws"
"github.com/netbirdio/netbird/util/embeddedroots"
nbnet "github.com/netbirdio/netbird/util/net"
)
type Dialer struct {
}
func (d Dialer) Protocol() string {
return "WS"
}
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
wsURL, err := prepareURL(address)
if err != nil {
return nil, err
}
opts := &websocket.DialOptions{
HTTPClient: httpClientNbDialer(),
}
parsedURL, err := url.Parse(wsURL)
if err != nil {
return nil, err
}
parsedURL.Path = ws.URLPath
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, err
}
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
return nil, err
}
if resp.Body != nil {
_ = resp.Body.Close()
}
conn := NewConn(wsConn, address)
return conn, nil
}
func prepareURL(address string) (string, error) {
if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
return "", fmt.Errorf("unsupported scheme: %s", address)
}
return strings.Replace(address, "rel", "ws", 1), nil
}
func httpClientNbDialer() *http.Client {
customDialer := nbnet.NewDialer()
certPool, err := x509.SystemCertPool()
if err != nil || certPool == nil {
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
certPool = embeddedroots.Get()
}
customTransport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return customDialer.DialContext(ctx, network, addr)
},
TLSClientConfig: &tls.Config{
RootCAs: certPool,
},
}
return &http.Client{
Transport: customTransport,
}
}

View File

@@ -0,0 +1,12 @@
/*
Package client contains the implementation of the Relay client.
The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages,
Keep persistent connection with the Relay server and handle the connection issues.
It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security).
If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to
the relay server. The connection with these relay servers will be closed if there is no active connection. The peers
negotiate the common relay instance via signaling service.
*/
package client

View File

@@ -0,0 +1,10 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,149 @@
package client
import (
"context"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
)
const (
// TODO: make it configurable, the manager should validate all configurable parameters
reconnectingTimeout = 60 * time.Second
)
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct {
// OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance.
OnNewRelayClient chan *Client
OnReconnected chan struct{}
serverPicker *ServerPicker
}
// NewGuard creates a new guard for the relay client.
func NewGuard(sp *ServerPicker) *Guard {
g := &Guard{
OnNewRelayClient: make(chan *Client, 1),
OnReconnected: make(chan struct{}, 1),
serverPicker: sp,
}
return g
}
// StartReconnectTrys is called when the relay client is disconnected from the relay server.
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
// to the same server that was used before, if the server URL is still valid. If the quick
// reconnect fails, it starts a ticker to periodically attempt server picking until it
// succeeds or the context is done.
//
// Parameters:
// - ctx: The context to control the lifecycle of the reconnection attempts.
// - relayClient: The relay client instance that was disconnected.
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
// try to reconnect to the same server
if ok := g.tryToQuickReconnect(ctx, relayClient); ok {
g.notifyReconnected()
return
}
// start a ticker to pick a new server
ticker := exponentTicker(ctx)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := g.retry(ctx); err != nil {
log.Errorf("failed to pick new Relay server: %s", err)
continue
}
return
case <-ctx.Done():
return
}
}
}
func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool {
if rc == nil {
return false
}
if !g.isServerURLStillValid(rc) {
return false
}
if cancelled := waiteBeforeRetry(parentCtx); !cancelled {
return false
}
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
if err := rc.Connect(parentCtx); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
return false
}
return true
}
func (g *Guard) retry(ctx context.Context) error {
log.Infof("try to pick up a new Relay server")
relayClient, err := g.serverPicker.PickServer(ctx)
if err != nil {
return err
}
// prevent to work with a deprecated Relay client instance
g.drainRelayClientChan()
g.OnNewRelayClient <- relayClient
return nil
}
func (g *Guard) drainRelayClientChan() {
select {
case <-g.OnNewRelayClient:
default:
}
}
func (g *Guard) isServerURLStillValid(rc *Client) bool {
for _, url := range g.serverPicker.ServerURLs.Load().([]string) {
if url == rc.connectionURL {
return true
}
}
return false
}
func (g *Guard) notifyReconnected() {
select {
case g.OnReconnected <- struct{}{}:
default:
}
}
func exponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 2 * time.Second,
Multiplier: 2,
MaxInterval: reconnectingTimeout,
Clock: backoff.SystemClock,
}, ctx)
return backoff.NewTicker(bo)
}
func waiteBeforeRetry(ctx context.Context) bool {
timer := time.NewTimer(1500 * time.Millisecond)
defer timer.Stop()
select {
case <-timer.C:
return true
case <-ctx.Done():
return false
}
}

View File

@@ -0,0 +1,404 @@
package client
import (
"container/list"
"context"
"fmt"
"net"
"reflect"
"sync"
"time"
log "github.com/sirupsen/logrus"
relayAuth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
)
var (
relayCleanupInterval = 60 * time.Second
keepUnusedServerTime = 5 * time.Second
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
// RelayTrack hold the relay clients for the foreign relay servers.
// With the mutex can ensure we can open new connection in case the relay connection has been established with
// the relay server.
type RelayTrack struct {
sync.RWMutex
relayClient *Client
err error
created time.Time
}
func NewRelayTrack() *RelayTrack {
return &RelayTrack{
created: time.Now(),
}
}
type OnServerCloseListener func()
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
// and automatically reconnect to them in case disconnection.
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
// different relay server, the manager will establish a new connection to the relay server. The connection with these
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it.
type Manager struct {
ctx context.Context
peerID string
running bool
tokenStore *relayAuth.TokenStore
serverPicker *ServerPicker
relayClient *Client
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
relayClientMu sync.RWMutex
reconnectGuard *Guard
relayClients map[string]*RelayTrack
relayClientsMutex sync.RWMutex
onDisconnectedListeners map[string]*list.List
onReconnectedListenerFn func()
listenerLock sync.Mutex
}
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
tokenStore := &relayAuth.TokenStore{}
m := &Manager{
ctx: ctx,
peerID: peerID,
tokenStore: tokenStore,
serverPicker: &ServerPicker{
TokenStore: tokenStore,
PeerID: peerID,
},
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
}
m.serverPicker.ServerURLs.Store(serverURLs)
m.reconnectGuard = NewGuard(m.serverPicker)
return m
}
// Serve starts the manager, attempting to establish a connection with the relay server.
// If the connection fails, it will keep trying to reconnect in the background.
// Additionally, it starts a cleanup loop to remove unused relay connections.
// The manager will automatically reconnect to the relay server in case of disconnection.
func (m *Manager) Serve() error {
if m.running {
return fmt.Errorf("manager already serving")
}
m.running = true
log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load())
client, err := m.serverPicker.PickServer(m.ctx)
if err != nil {
go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
} else {
m.storeClient(client)
}
go m.listenGuardEvent(m.ctx)
go m.startCleanupLoop()
return err
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
m.relayClientMu.RLock()
defer m.relayClientMu.RUnlock()
if m.relayClient == nil {
return nil, ErrRelayClientNotConnected
}
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return nil, err
}
var (
netConn net.Conn
)
if !foreign {
log.Debugf("open peer connection via permanent server: %s", peerKey)
netConn, err = m.relayClient.OpenConn(ctx, peerKey)
} else {
log.Debugf("open peer connection via foreign server: %s", serverAddress)
netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
}
if err != nil {
return nil, err
}
return netConn, err
}
// Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool {
m.relayClientMu.RLock()
defer m.relayClientMu.RUnlock()
if m.relayClient == nil {
return false
}
return m.relayClient.Ready()
}
func (m *Manager) SetOnReconnectedListener(f func()) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
m.onReconnectedListenerFn = f
}
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
m.relayClientMu.RLock()
defer m.relayClientMu.RUnlock()
if m.relayClient == nil {
return ErrRelayClientNotConnected
}
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return err
}
var listenerAddr string
if foreign {
listenerAddr = serverAddress
} else {
listenerAddr = m.relayClient.connectionURL
}
m.addListener(listenerAddr, onClosedListener)
return nil
}
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) {
m.relayClientMu.RLock()
defer m.relayClientMu.RUnlock()
if m.relayClient == nil {
return "", ErrRelayClientNotConnected
}
return m.relayClient.ServerInstanceURL()
}
// ServerURLs returns the addresses of the relay servers.
func (m *Manager) ServerURLs() []string {
return m.serverPicker.ServerURLs.Load().([]string)
}
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
// Relay service.
func (m *Manager) HasRelayAddress() bool {
return len(m.serverPicker.ServerURLs.Load().([]string)) > 0
}
func (m *Manager) UpdateServerURLs(serverURLs []string) {
log.Infof("update relay server URLs: %v", serverURLs)
m.serverPicker.ServerURLs.Store(serverURLs)
}
// UpdateToken updates the token in the token store.
func (m *Manager) UpdateToken(token *relayAuth.Token) error {
return m.tokenStore.UpdateToken(token)
}
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
// check if already has a connection to the desired relay server
m.relayClientsMutex.RLock()
rt, ok := m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.RUnlock()
defer rt.RUnlock()
if rt.err != nil {
return nil, rt.err
}
return rt.relayClient.OpenConn(ctx, peerKey)
}
m.relayClientsMutex.RUnlock()
// if not, establish a new connection but check it again (because changed the lock type) before starting the
// connection
m.relayClientsMutex.Lock()
rt, ok = m.relayClients[serverAddress]
if ok {
rt.RLock()
m.relayClientsMutex.Unlock()
defer rt.RUnlock()
if rt.err != nil {
return nil, rt.err
}
return rt.relayClient.OpenConn(ctx, peerKey)
}
// create a new relay client and store it in the relayClients map
rt = NewRelayTrack()
rt.Lock()
m.relayClients[serverAddress] = rt
m.relayClientsMutex.Unlock()
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
err := relayClient.Connect(m.ctx)
if err != nil {
rt.err = err
rt.Unlock()
m.relayClientsMutex.Lock()
delete(m.relayClients, serverAddress)
m.relayClientsMutex.Unlock()
return nil, err
}
// if connection closed then delete the relay client from the list
relayClient.SetOnDisconnectListener(m.onServerDisconnected)
rt.relayClient = relayClient
rt.Unlock()
conn, err := relayClient.OpenConn(ctx, peerKey)
if err != nil {
return nil, err
}
return conn, nil
}
func (m *Manager) onServerConnected() {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
if m.onReconnectedListenerFn == nil {
return
}
go m.onReconnectedListenerFn()
}
// onServerDisconnected start to reconnection for home server only
func (m *Manager) onServerDisconnected(serverAddress string) {
m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL {
go func(client *Client) {
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
}(m.relayClient)
}
m.relayClientMu.Unlock()
m.notifyOnDisconnectListeners(serverAddress)
}
func (m *Manager) listenGuardEvent(ctx context.Context) {
for {
select {
case <-m.reconnectGuard.OnReconnected:
m.onServerConnected()
case rc := <-m.reconnectGuard.OnNewRelayClient:
m.storeClient(rc)
m.onServerConnected()
case <-ctx.Done():
return
}
}
}
func (m *Manager) storeClient(client *Client) {
m.relayClientMu.Lock()
defer m.relayClientMu.Unlock()
m.relayClient = client
m.relayClient.SetOnDisconnectListener(m.onServerDisconnected)
}
func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.ServerInstanceURL()
if err != nil {
return false, fmt.Errorf("relay client not connected")
}
return rAddr != address, nil
}
func (m *Manager) startCleanupLoop() {
ticker := time.NewTicker(relayCleanupInterval)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
m.cleanUpUnusedRelays()
}
}
}
func (m *Manager) cleanUpUnusedRelays() {
m.relayClientsMutex.Lock()
defer m.relayClientsMutex.Unlock()
for addr, rt := range m.relayClients {
rt.Lock()
// if the connection failed to the server the relay client will be nil
// but the instance will be kept in the relayClients until the next locking
if rt.err != nil {
rt.Unlock()
continue
}
if time.Since(rt.created) <= keepUnusedServerTime {
rt.Unlock()
continue
}
if rt.relayClient.HasConns() {
rt.Unlock()
continue
}
rt.relayClient.SetOnDisconnectListener(nil)
go func() {
_ = rt.relayClient.Close()
}()
log.Debugf("clean up unused relay server connection: %s", addr)
delete(m.relayClients, addr)
rt.Unlock()
}
}
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
l = list.New()
}
for e := l.Front(); e != nil; e = e.Next() {
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
return
}
}
l.PushBack(onClosedListener)
m.onDisconnectedListeners[serverAddress] = l
}
func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
m.listenerLock.Lock()
defer m.listenerLock.Unlock()
l, ok := m.onDisconnectedListeners[serverAddress]
if !ok {
return
}
for e := l.Front(); e != nil; e = e.Next() {
go e.Value.(OnServerCloseListener)()
}
delete(m.onDisconnectedListeners, serverAddress)
}

View File

@@ -0,0 +1,469 @@
package client
import (
"context"
"testing"
"time"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/shared/relay/auth/allow"
"github.com/netbirdio/netbird/relay/server"
)
func TestEmptyURL(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
mgr := NewManager(ctx, nil, "alice")
err := mgr.Serve()
if err == nil {
t.Errorf("expected error, got nil")
}
}
func TestForeignConn(t *testing.T) {
ctx := context.Background()
lstCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: lstCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(lstCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: srvCfg2.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
if err := clientAlice.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
if err := clientBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
if err != nil {
t.Fatalf("failed to get relay address: %s", err)
}
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
payload := "hello bob, I am alice"
_, err = connAliceToBob.Write([]byte(payload))
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
buf := make([]byte, 65535)
n, err := connBobToAlice.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
_, err = connBobToAlice.Write(buf[:n])
if err != nil {
t.Fatalf("failed to write to channel: %s", err)
}
n, err = connAliceToBob.Read(buf)
if err != nil {
t.Fatalf("failed to read from channel: %s", err)
}
if payload != string(buf[:n]) {
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
}
}
func TestForeginConnClose(t *testing.T) {
ctx := context.Background()
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
err := srv1.Listen(srvCfg1)
if err != nil {
errChan <- err
}
}()
defer func() {
err := srv1.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
if err := mgrBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
err = mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
err = conn.Close()
if err != nil {
t.Fatalf("failed to close connection: %s", err)
}
}
func TestForeignAutoClose(t *testing.T) {
ctx := context.Background()
relayCleanupInterval = 1 * time.Second
keepUnusedServerTime = 2 * time.Second
srvCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv1, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
t.Log("binding server 1.")
if err := srv1.Listen(srvCfg1); err != nil {
errChan <- err
}
}()
defer func() {
t.Logf("closing server 1.")
if err := srv1.Shutdown(ctx); err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 1. closed")
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
srvCfg2 := server.ListenerConfig{
Address: "localhost:2234",
}
srv2, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan2 := make(chan error, 1)
go func() {
t.Log("binding server 2.")
err := srv2.Listen(srvCfg2)
if err != nil {
errChan2 <- err
}
}()
defer func() {
t.Logf("closing server 2.")
err := srv2.Shutdown(ctx)
if err != nil {
t.Errorf("failed to close server: %s", err)
}
t.Logf("server 2 closed.")
}()
if err := waitForServerToStart(errChan2); err != nil {
t.Fatalf("failed to start server: %s", err)
}
idAlice := "alice"
t.Log("connect to server 1.")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
err = mgr.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
// Set up a disconnect listener to track when foreign server disconnects
foreignServerURL := toURL(srvCfg2)[0]
disconnected := make(chan struct{})
onDisconnect := func() {
select {
case disconnected <- struct{}{}:
default:
}
}
t.Log("open connection to another peer")
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
t.Fatalf("should have failed to open connection to another peer")
}
// Add the disconnect listener after the connection attempt
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
t.Logf("failed to add close listener (expected if connection failed): %s", err)
}
// Wait for cleanup to happen
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
t.Logf("waiting for relay cleanup: %s", timeout)
select {
case <-disconnected:
t.Log("foreign relay connection cleaned up successfully")
case <-time.After(timeout):
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
}
t.Logf("closing manager")
}
func TestAutoReconnect(t *testing.T) {
ctx := context.Background()
srvCfg := server.ListenerConfig{
Address: "localhost:1234",
}
srv, err := server.NewServer(serverCfg)
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
if err := srv.Listen(srvCfg); err != nil {
errChan <- err
}
}()
defer func() {
err := srv.Shutdown(ctx)
if err != nil {
log.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
err = clientBob.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
err = clientAlice.Serve()
if err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
ra, err := clientAlice.RelayInstanceAddress()
if err != nil {
t.Errorf("failed to get relay address: %s", err)
}
conn, err := clientAlice.OpenConn(ctx, ra, "bob")
if err != nil {
t.Errorf("failed to bind channel: %s", err)
}
t.Log("closing client relay connection")
// todo figure out moc server
_ = clientAlice.relayClient.relayConn.Close()
t.Log("start test reading")
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Errorf("unexpected reading from closed connection")
}
log.Infof("waiting for reconnection")
time.Sleep(reconnectingTimeout + 1*time.Second)
log.Infof("reopent the connection")
_, err = clientAlice.OpenConn(ctx, ra, "bob")
if err != nil {
t.Errorf("failed to open channel: %s", err)
}
}
func TestNotifierDoubleAdd(t *testing.T) {
ctx := context.Background()
listenerCfg1 := server.ListenerConfig{
Address: "localhost:1234",
}
srv, err := server.NewServer(server.Config{
Meter: otel.Meter(""),
ExposedAddress: listenerCfg1.Address,
TLSSupport: false,
AuthValidator: &allow.Auth{},
})
if err != nil {
t.Fatalf("failed to create server: %s", err)
}
errChan := make(chan error, 1)
go func() {
if err := srv.Listen(listenerCfg1); err != nil {
errChan <- err
}
}()
defer func() {
if err := srv.Shutdown(ctx); err != nil {
t.Errorf("failed to close server: %s", err)
}
}()
if err := waitForServerToStart(errChan); err != nil {
t.Fatalf("failed to start server: %s", err)
}
log.Debugf("connect by alice")
mCtx, cancel := context.WithCancel(ctx)
defer cancel()
clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
if err = clientBob.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
if err = clientAlice.Serve(); err != nil {
t.Fatalf("failed to serve manager: %s", err)
}
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
if err != nil {
t.Fatalf("failed to bind channel: %s", err)
}
fnCloseListener := OnServerCloseListener(func() {
log.Infof("close listener")
})
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
if err != nil {
t.Fatalf("failed to add close listener: %s", err)
}
err = conn1.Close()
if err != nil {
t.Errorf("failed to close connection: %s", err)
}
}
func toURL(address server.ListenerConfig) []string {
return []string{"rel://" + address.Address}
}

View File

@@ -0,0 +1,191 @@
package client
import (
"context"
"errors"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/relay/messages"
)
const (
OpenConnectionTimeout = 30 * time.Second
)
type relayedConnWriter interface {
Write(p []byte) (n int, err error)
}
// PeersStateSubscription manages subscriptions to peer state changes (online/offline)
// over a relay connection. It allows tracking peers' availability and handling offline
// events via a callback. We get online notification from the server only once.
type PeersStateSubscription struct {
log *log.Entry
relayConn relayedConnWriter
offlineCallback func(peerIDs []messages.PeerID)
listenForOfflinePeers map[messages.PeerID]struct{}
waitingPeers map[messages.PeerID]chan struct{}
mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers
}
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
return &PeersStateSubscription{
log: log,
relayConn: relayConn,
offlineCallback: offlineCallback,
listenForOfflinePeers: make(map[messages.PeerID]struct{}),
waitingPeers: make(map[messages.PeerID]chan struct{}),
}
}
// OnPeersOnline should be called when a notification is received that certain peers have come online.
// It checks if any of the peers are being waited on and signals their availability.
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, peerID := range peersID {
waitCh, ok := s.waitingPeers[peerID]
if !ok {
// If meanwhile the peer was unsubscribed, we don't need to signal it
continue
}
waitCh <- struct{}{}
delete(s.waitingPeers, peerID)
close(waitCh)
}
}
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
s.mu.Lock()
relevantPeers := make([]messages.PeerID, 0, len(peersID))
for _, peerID := range peersID {
if _, ok := s.listenForOfflinePeers[peerID]; ok {
relevantPeers = append(relevantPeers, peerID)
}
}
s.mu.Unlock()
if len(relevantPeers) > 0 {
s.offlineCallback(relevantPeers)
}
}
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
// Check if already waiting for this peer
s.mu.Lock()
if _, exists := s.waitingPeers[peerID]; exists {
s.mu.Unlock()
return errors.New("already waiting for peer to come online")
}
// Create a channel to wait for the peer to come online
waitCh := make(chan struct{}, 1)
s.waitingPeers[peerID] = waitCh
s.listenForOfflinePeers[peerID] = struct{}{}
s.mu.Unlock()
if err := s.subscribeStateChange(peerID); err != nil {
s.log.Errorf("failed to subscribe to peer state: %s", err)
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return err
}
// Wait for peer to come online or context to be cancelled
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
defer cancel()
select {
case _, ok := <-waitCh:
if !ok {
return fmt.Errorf("wait for peer to come online has been cancelled")
}
s.log.Debugf("peer %s is now online", peerID)
return nil
case <-timeoutCtx.Done():
s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
s.log.Errorf("failed to unsubscribe from peer state: %s", err)
}
s.mu.Lock()
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
close(waitCh)
delete(s.waitingPeers, peerID)
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return timeoutCtx.Err()
}
}
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
msgErr := s.unsubscribeStateChange(peerIDs)
s.mu.Lock()
for _, peerID := range peerIDs {
if wch, ok := s.waitingPeers[peerID]; ok {
close(wch)
delete(s.waitingPeers, peerID)
}
delete(s.listenForOfflinePeers, peerID)
}
s.mu.Unlock()
return msgErr
}
func (s *PeersStateSubscription) Cleanup() {
s.mu.Lock()
defer s.mu.Unlock()
for _, waitCh := range s.waitingPeers {
close(waitCh)
}
s.waitingPeers = make(map[messages.PeerID]chan struct{})
s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
}
func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error {
msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID})
if err != nil {
return err
}
for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
return err
}
}
return nil
}
func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error {
msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs)
if err != nil {
return err
}
var connWriteErr error
for _, msg := range msgs {
if _, err := s.relayConn.Write(msg); err != nil {
connWriteErr = err
}
}
return connWriteErr
}

View File

@@ -0,0 +1,99 @@
package client
import (
"bytes"
"context"
"testing"
"time"
"github.com/netbirdio/netbird/relay/messages"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockRelayedConn struct {
}
func (m *mockRelayedConn) Write(p []byte) (n int, err error) {
return len(p), nil
}
func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
peerID := messages.HashID("peer1")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{}) // discard log output
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Launch wait in background
go func() {
time.Sleep(100 * time.Millisecond)
sub.OnPeersOnline([]messages.PeerID{peerID})
}()
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
assert.NoError(t, err)
}
func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
peerID := messages.HashID("peer2")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
}
func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
peerID := messages.HashID("peer3")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
ctx := context.Background()
go func() {
_ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
}()
time.Sleep(100 * time.Millisecond)
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "already waiting")
}
func TestUnsubscribeStateChange(t *testing.T) {
peerID := messages.HashID("peer4")
mockConn := &mockRelayedConn{}
logger := logrus.New()
logger.SetOutput(&bytes.Buffer{})
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
doneChan := make(chan struct{})
go func() {
_ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID)
close(doneChan)
}()
time.Sleep(100 * time.Millisecond)
err := sub.UnsubscribeStateChange([]messages.PeerID{peerID})
assert.NoError(t, err)
select {
case <-doneChan:
case <-time.After(200 * time.Millisecond):
// Expected timeout, meaning the subscription was successfully unsubscribed
t.Errorf("timeout")
}
}

View File

@@ -0,0 +1,104 @@
package client
import (
"context"
"errors"
"fmt"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
)
const (
maxConcurrentServers = 7
)
var (
connectionTimeout = 30 * time.Second
)
type connResult struct {
RelayClient *Client
Url string
Err error
}
type ServerPicker struct {
TokenStore *auth.TokenStore
ServerURLs atomic.Value
PeerID string
}
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
totalServers := len(sp.ServerURLs.Load().([]string))
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
for _, url := range sp.ServerURLs.Load().([]string) {
// todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{}
go func(url string) {
defer func() {
<-concurrentLimiter
}()
sp.startConnection(parentCtx, connResultChan, url)
}(url)
}
go sp.processConnResults(connResultChan, successChan)
select {
case cr, ok := <-successChan:
if !ok {
return nil, errors.New("failed to connect to any relay server: all attempts failed")
}
log.Infof("chosen home Relay server: %s", cr.Url)
return cr.RelayClient, nil
case <-ctx.Done():
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
}
}
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
log.Infof("try to connecting to relay server: %s", url)
relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
err := relayClient.Connect(ctx)
resultChan <- connResult{
RelayClient: relayClient,
Url: url,
Err: err,
}
}
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
var hasSuccess bool
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
if hasSuccess {
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
if err := cr.RelayClient.Close(); err != nil {
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
}
continue
}
hasSuccess = true
successChan <- cr
}
close(successChan)
}

View File

@@ -0,0 +1,34 @@
package client
import (
"context"
"errors"
"testing"
"time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
connectionTimeout = 5 * time.Second
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
}
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
defer cancel()
go func() {
_, err := sp.PickServer(ctx)
if err == nil {
t.Error(err)
}
cancel()
}()
<-ctx.Done()
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Errorf("PickServer() took too long to complete")
}
}

View File

@@ -0,0 +1,77 @@
package client
import (
"context"
"fmt"
"io"
"strings"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/version"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
// Status is the status of the client
type Status string
const StreamConnected Status = "Connected"
const StreamDisconnected Status = "Disconnected"
const (
// DirectCheck indicates support to direct mode checks
DirectCheck uint32 = 1
)
type Client interface {
io.Closer
StreamConnected() bool
GetStatus() Status
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
Ready() bool
IsHealthy() bool
WaitStreamConnected()
SendToStream(msg *proto.EncryptedMessage) error
Send(msg *proto.Message) error
SetOnReconnectedListener(func())
}
// UnMarshalCredential parses the credentials from the message and returns a Credential instance
func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
credential := strings.Split(msg.GetBody().GetPayload(), ":")
if len(credential) != 2 {
return nil, fmt.Errorf("error parsing message body %s", msg.Body)
}
return &Credential{
UFrag: credential[0],
Pwd: credential[1],
}, nil
}
// MarshalCredential marshal a Credential instance and returns a Message object
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
RemoteKey: remoteKey,
Body: &proto.Body{
Type: t,
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
WgListenPort: uint32(myPort),
NetBirdVersion: version.NetbirdVersion(),
RosenpassConfig: &proto.RosenpassConfig{
RosenpassPubKey: rosenpassPubKey,
RosenpassServerAddr: rosenpassAddr,
},
RelayServerAddress: relaySrvAddress,
},
}, nil
}
// Credential is an instance of a GrpcClient's Credential
type Credential struct {
UFrag string
Pwd string
}

View File

@@ -0,0 +1,13 @@
package client
import (
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"testing"
)
func TestSignal(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Signal Suite")
}

View File

@@ -0,0 +1,230 @@
package client
import (
"context"
"net"
"sync"
"time"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata"
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/server"
)
var _ = Describe("GrpcClient", func() {
var (
addr string
listener net.Listener
server *grpc.Server
)
BeforeEach(func() {
server, listener = startSignal()
addr = listener.Addr().String()
})
AfterEach(func() {
server.Stop()
listener.Close()
})
Describe("Exchanging messages", func() {
Context("between connected peers", func() {
It("should be successful", func() {
var msgReceived sync.WaitGroup
msgReceived.Add(2)
var payloadReceivedOnA string
var payloadReceivedOnB string
var featuresSupportedReceivedOnA []uint32
var featuresSupportedReceivedOnB []uint32
// connect PeerA to Signal
keyA, _ := wgtypes.GenerateKey()
clientA := createSignalClient(addr, keyA)
go func() {
err := clientA.Receive(context.Background(), func(msg *sigProto.Message) error {
payloadReceivedOnA = msg.GetBody().GetPayload()
featuresSupportedReceivedOnA = msg.GetBody().GetFeaturesSupported()
msgReceived.Done()
return nil
})
if err != nil {
return
}
}()
clientA.WaitStreamConnected()
// connect PeerB to Signal
keyB, _ := wgtypes.GenerateKey()
clientB := createSignalClient(addr, keyB)
go func() {
err := clientB.Receive(context.Background(), func(msg *sigProto.Message) error {
payloadReceivedOnB = msg.GetBody().GetPayload()
featuresSupportedReceivedOnB = msg.GetBody().GetFeaturesSupported()
err := clientB.Send(&sigProto.Message{
Key: keyB.PublicKey().String(),
RemoteKey: keyA.PublicKey().String(),
Body: &sigProto.Body{Payload: "pong"},
})
if err != nil {
Fail("failed sending a message to PeerA")
}
msgReceived.Done()
return nil
})
if err != nil {
return
}
}()
clientB.WaitStreamConnected()
// PeerA initiates ping-pong
err := clientA.Send(&sigProto.Message{
Key: keyA.PublicKey().String(),
RemoteKey: keyB.PublicKey().String(),
Body: &sigProto.Body{Payload: "ping", FeaturesSupported: []uint32{DirectCheck}},
})
if err != nil {
Fail("failed sending a message to PeerB")
}
if waitTimeout(&msgReceived, 3*time.Second) {
Fail("test timed out on waiting for peers to exchange messages")
}
Expect(payloadReceivedOnA).To(BeEquivalentTo("pong"))
Expect(payloadReceivedOnB).To(BeEquivalentTo("ping"))
Expect(featuresSupportedReceivedOnA).To(BeNil())
Expect(featuresSupportedReceivedOnB).To(ContainElements([]uint32{DirectCheck}))
})
})
})
Describe("Connecting to the Signal stream channel", func() {
Context("with a signal client", func() {
It("should be successful", func() {
key, _ := wgtypes.GenerateKey()
client := createSignalClient(addr, key)
go func() {
err := client.Receive(context.Background(), func(msg *sigProto.Message) error {
return nil
})
if err != nil {
return
}
}()
client.WaitStreamConnected()
Expect(client).NotTo(BeNil())
})
})
Context("with a raw client and no Id header", func() {
It("should fail", func() {
client := createRawSignalClient(addr)
stream, err := client.ConnectStream(context.Background())
if err != nil {
Fail("error connecting to stream")
}
_, err = stream.Recv()
Expect(stream).NotTo(BeNil())
Expect(err).NotTo(BeNil())
})
})
Context("with a raw client and with an Id header", func() {
It("should be successful", func() {
md := metadata.New(map[string]string{sigProto.HeaderId: "peer"})
ctx := metadata.NewOutgoingContext(context.Background(), md)
client := createRawSignalClient(addr)
stream, err := client.ConnectStream(ctx)
Expect(stream).NotTo(BeNil())
Expect(err).To(BeNil())
})
})
})
})
func createSignalClient(addr string, key wgtypes.Key) *GrpcClient {
var sigTLSEnabled = false
client, err := NewClient(context.Background(), addr, key, sigTLSEnabled)
if err != nil {
Fail("failed creating signal client")
}
return client
}
func createRawSignalClient(addr string) sigProto.SignalExchangeClient {
ctx := context.Background()
conn, err := grpc.DialContext(ctx, addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 3 * time.Second,
Timeout: 2 * time.Second,
}))
if err != nil {
Fail("failed creating raw signal client")
}
return sigProto.NewSignalExchangeClient(conn)
}
func startSignal() (*grpc.Server, net.Listener) {
lis, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
s := grpc.NewServer()
srv, err := server.NewServer(context.Background(), otel.Meter(""))
if err != nil {
panic(err)
}
sigProto.RegisterSignalExchangeServer(s, srv)
go func() {
if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis
}
// waitTimeout waits for the waitgroup for the specified max timeout.
// Returns true if waiting timed out.
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
c := make(chan struct{})
go func() {
defer close(c)
wg.Wait()
}()
select {
case <-c:
return false // completed normally
case <-time.After(timeout):
return true // timed out
}
}

View File

@@ -0,0 +1,10 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -0,0 +1,437 @@
package client
import (
"context"
"fmt"
"io"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/signal/proto"
nbgrpc "github.com/netbirdio/netbird/util/grpc"
)
// ConnStateNotifier is a wrapper interface of the status recorder
type ConnStateNotifier interface {
MarkSignalDisconnected(error)
MarkSignalConnected()
}
// GrpcClient Wraps the Signal Exchange Service gRpc client
type GrpcClient struct {
key wgtypes.Key
realClient proto.SignalExchangeClient
signalConn *grpc.ClientConn
ctx context.Context
stream proto.SignalExchange_ConnectStreamClient
// connectedCh used to notify goroutines waiting for the connection to the Signal stream
connectedCh chan struct{}
mux sync.Mutex
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
status Status
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
onReconnectedListenerFn func()
}
func (c *GrpcClient) StreamConnected() bool {
return c.status == StreamConnected
}
func (c *GrpcClient) GetStatus() Status {
return c.status
}
// Close Closes underlying connections to the Signal Exchange
func (c *GrpcClient) Close() error {
return c.signalConn.Close()
}
// NewClient creates a new Signal client
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
var conn *grpc.ClientConn
operation := func() error {
var err error
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
if err != nil {
log.Printf("createConnection error: %v", err)
return err
}
return nil
}
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
if err != nil {
log.Errorf("failed to connect to the signalling server: %v", err)
return nil, err
}
log.Debugf("connected to Signal Service: %v", conn.Target())
return &GrpcClient{
realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx,
signalConn: conn,
key: key,
mux: sync.Mutex{},
status: StreamDisconnected,
connStateCallbackLock: sync.RWMutex{},
}, nil
}
// SetConnStateListener set the ConnStateNotifier
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
c.connStateCallbackLock.Lock()
defer c.connStateCallbackLock.Unlock()
c.connStateCallback = notifier
}
// defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 1,
Multiplier: 1.7,
MaxInterval: 10 * time.Second,
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}, ctx)
}
// Receive Connects to the Signal Exchange message stream and starts receiving messages.
// The messages will be handled by msgHandler function provided.
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error {
var backOff = defaultBackoff(ctx)
operation := func() error {
c.notifyStreamDisconnected()
log.Debugf("signal connection state %v", c.signalConn.GetState())
connState := c.signalConn.GetState()
if connState == connectivity.Shutdown {
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
c.signalConn.WaitForStateChange(ctx, connState)
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
}
// connect to Signal stream identifying ourselves with a public WireGuard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
stream, err := c.connect(ctx, c.key.PublicKey().String())
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
return err
}
c.notifyStreamConnected()
log.Infof("connected to the Signal Service stream")
c.notifyConnected()
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler)
if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
log.Debugf("signal connection context has been canceled, this usually indicates shutdown")
return nil
}
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay
backOff.Reset()
c.notifyDisconnected(err)
log.Warnf("disconnected from the Signal service but will retry silently. Reason: %v", err)
return err
}
return nil
}
err := backoff.Retry(operation, backOff)
if err != nil {
log.Errorf("exiting the Signal service connection retry loop due to the unrecoverable error: %v", err)
return err
}
return nil
}
func (c *GrpcClient) notifyStreamDisconnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = StreamDisconnected
}
func (c *GrpcClient) notifyStreamConnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = StreamConnected
if c.connectedCh != nil {
// there are goroutines waiting on this channel -> release them
close(c.connectedCh)
c.connectedCh = nil
}
if c.onReconnectedListenerFn != nil {
c.onReconnectedListenerFn()
}
}
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
}
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil
// add key fingerprint to the request header to be identified on the server side
md := metadata.New(map[string]string{proto.HeaderId: key})
metaCtx := metadata.NewOutgoingContext(ctx, md)
stream, err := c.realClient.ConnectStream(metaCtx, grpc.WaitForReady(true))
c.stream = stream
if err != nil {
return nil, err
}
// blocks
header, err := c.stream.Header()
if err != nil {
return nil, err
}
registered := header.Get(proto.HeaderRegistered)
if len(registered) == 0 {
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
}
return stream, nil
}
// Ready indicates whether the client is okay and Ready to be used
// for now it just checks whether gRPC connection to the service is in state Ready
func (c *GrpcClient) Ready() bool {
return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle
}
// IsHealthy probes the gRPC connection and returns false on errors
func (c *GrpcClient) IsHealthy() bool {
switch c.signalConn.GetState() {
case connectivity.TransientFailure:
return false
case connectivity.Connecting:
return true
case connectivity.Shutdown:
return true
case connectivity.Idle:
case connectivity.Ready:
}
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
defer cancel()
_, err := c.realClient.Send(ctx, &proto.EncryptedMessage{
Key: c.key.PublicKey().String(),
RemoteKey: "dummy",
Body: nil,
})
if err != nil {
c.notifyDisconnected(err)
log.Warnf("health check returned: %s", err)
return false
}
c.notifyConnected()
return true
}
// WaitStreamConnected waits until the client is connected to the Signal stream
func (c *GrpcClient) WaitStreamConnected() {
if c.status == StreamConnected {
return
}
ch := c.getStreamStatusChan()
select {
case <-c.ctx.Done():
case <-ch:
}
}
func (c *GrpcClient) SetOnReconnectedListener(fn func()) {
c.mux.Lock()
defer c.mux.Unlock()
c.onReconnectedListenerFn = fn
}
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
// The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
// GrpcClient.connWg can be used to wait
func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
if c.stream == nil {
return fmt.Errorf("connection to the Signal Exchange has not been established yet. Please call GrpcClient.Receive before sending messages")
}
err := c.stream.Send(msg)
if err != nil {
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
return err
}
return nil
}
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
func (c *GrpcClient) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
if err != nil {
return nil, err
}
body := &proto.Body{}
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
if err != nil {
return nil, err
}
return &proto.Message{
Key: msg.Key,
RemoteKey: msg.RemoteKey,
Body: body,
}, nil
}
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
if err != nil {
return nil, err
}
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
if err != nil {
return nil, err
}
return &proto.EncryptedMessage{
Key: msg.GetKey(),
RemoteKey: msg.GetRemoteKey(),
Body: encryptedBody,
}, nil
}
// Send sends a message to the remote Peer through the Signal Exchange.
func (c *GrpcClient) Send(msg *proto.Message) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
encryptedMessage, err := c.encryptMessage(msg)
if err != nil {
return err
}
attemptTimeout := client.ConnectTimeout
for attempt := 0; attempt < 4; attempt++ {
if attempt > 1 {
attemptTimeout = time.Duration(attempt) * 5 * time.Second
}
ctx, cancel := context.WithTimeout(c.ctx, attemptTimeout)
_, err = c.realClient.Send(ctx, encryptedMessage)
cancel()
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
return err
}
if err == nil {
return nil
}
}
return err
}
// receive receives messages from other peers coming through the Signal Exchange
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
msgHandler func(msg *proto.Message) error) error {
for {
msg, err := stream.Recv()
switch s, ok := status.FromError(err); {
case ok && s.Code() == codes.Canceled:
log.Debugf("stream canceled (usually indicates shutdown)")
return err
case s.Code() == codes.Unavailable:
log.Debugf("Signal Service is unavailable")
return err
case err == io.EOF:
log.Debugf("Signal Service stream closed by server")
return err
case err != nil:
return err
}
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
decryptedMessage, err := c.decryptMessage(msg)
if err != nil {
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
}
err = msgHandler(decryptedMessage)
if err != nil {
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
// todo send something??
}
}
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkSignalDisconnected(err)
}
func (c *GrpcClient) notifyConnected() {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
if c.connStateCallback == nil {
return
}
c.connStateCallback.MarkSignalConnected()
}

View File

@@ -0,0 +1,84 @@
package client
import (
"context"
"github.com/netbirdio/netbird/shared/signal/proto"
)
type MockClient struct {
CloseFunc func() error
GetStatusFunc func() Status
StreamConnectedFunc func() bool
ReadyFunc func() bool
WaitStreamConnectedFunc func()
ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error
SendToStreamFunc func(msg *proto.EncryptedMessage) error
SendFunc func(msg *proto.Message) error
SetOnReconnectedListenerFunc func(f func())
}
// SetOnReconnectedListener sets the function to be called when the client reconnects.
func (sm *MockClient) SetOnReconnectedListener(_ func()) {
// Do nothing
}
func (sm *MockClient) IsHealthy() bool {
return true
}
func (sm *MockClient) Close() error {
if sm.CloseFunc == nil {
return nil
}
return sm.CloseFunc()
}
func (sm *MockClient) GetStatus() Status {
if sm.GetStatusFunc == nil {
return ""
}
return sm.GetStatusFunc()
}
func (sm *MockClient) StreamConnected() bool {
if sm.StreamConnectedFunc == nil {
return false
}
return sm.StreamConnectedFunc()
}
func (sm *MockClient) Ready() bool {
if sm.ReadyFunc == nil {
return false
}
return sm.ReadyFunc()
}
func (sm *MockClient) WaitStreamConnected() {
if sm.WaitStreamConnectedFunc == nil {
return
}
sm.WaitStreamConnectedFunc()
}
func (sm *MockClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error {
if sm.ReceiveFunc == nil {
return nil
}
return sm.ReceiveFunc(ctx, msgHandler)
}
func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error {
if sm.SendToStreamFunc == nil {
return nil
}
return sm.SendToStreamFunc(msg)
}
func (sm *MockClient) Send(msg *proto.Message) error {
if sm.SendFunc == nil {
return nil
}
return sm.SendFunc(msg)
}

View File

@@ -0,0 +1,5 @@
package proto
// protocol constants, field names that can be used by both client and server
const HeaderId = "x-wiretrustee-peer-id"
const HeaderRegistered = "x-wiretrustee-peer-registered"

17
shared/signal/proto/generate.sh Executable file
View File

@@ -0,0 +1,17 @@
#!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I ./ ./signalexchange.proto --go_out=../ --go-grpc_out=../
cd "$old_pwd"

View File

@@ -0,0 +1,2 @@
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=

View File

@@ -0,0 +1,624 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.21.12
// source: signalexchange.proto
package proto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
_ "google.golang.org/protobuf/types/descriptorpb"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Message type
type Body_Type int32
const (
Body_OFFER Body_Type = 0
Body_ANSWER Body_Type = 1
Body_CANDIDATE Body_Type = 2
Body_MODE Body_Type = 4
Body_GO_IDLE Body_Type = 5
)
// Enum value maps for Body_Type.
var (
Body_Type_name = map[int32]string{
0: "OFFER",
1: "ANSWER",
2: "CANDIDATE",
4: "MODE",
5: "GO_IDLE",
}
Body_Type_value = map[string]int32{
"OFFER": 0,
"ANSWER": 1,
"CANDIDATE": 2,
"MODE": 4,
"GO_IDLE": 5,
}
)
func (x Body_Type) Enum() *Body_Type {
p := new(Body_Type)
*p = x
return p
}
func (x Body_Type) String() string {
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
}
func (Body_Type) Descriptor() protoreflect.EnumDescriptor {
return file_signalexchange_proto_enumTypes[0].Descriptor()
}
func (Body_Type) Type() protoreflect.EnumType {
return &file_signalexchange_proto_enumTypes[0]
}
func (x Body_Type) Number() protoreflect.EnumNumber {
return protoreflect.EnumNumber(x)
}
// Deprecated: Use Body_Type.Descriptor instead.
func (Body_Type) EnumDescriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{2, 0}
}
// Used for sending through signal.
// The body of this message is the Body message encrypted with the Wireguard private key and the remote Peer key
type EncryptedMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Wireguard public key
Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"`
// Wireguard public key of the remote peer to connect to
RemoteKey string `protobuf:"bytes,3,opt,name=remoteKey,proto3" json:"remoteKey,omitempty"`
// encrypted message Body
Body []byte `protobuf:"bytes,4,opt,name=body,proto3" json:"body,omitempty"`
}
func (x *EncryptedMessage) Reset() {
*x = EncryptedMessage{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *EncryptedMessage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*EncryptedMessage) ProtoMessage() {}
func (x *EncryptedMessage) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use EncryptedMessage.ProtoReflect.Descriptor instead.
func (*EncryptedMessage) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{0}
}
func (x *EncryptedMessage) GetKey() string {
if x != nil {
return x.Key
}
return ""
}
func (x *EncryptedMessage) GetRemoteKey() string {
if x != nil {
return x.RemoteKey
}
return ""
}
func (x *EncryptedMessage) GetBody() []byte {
if x != nil {
return x.Body
}
return nil
}
// A decrypted representation of the EncryptedMessage. Used locally before/after encryption
type Message struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// WireGuard public key
Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"`
// WireGuard public key of the remote peer to connect to
RemoteKey string `protobuf:"bytes,3,opt,name=remoteKey,proto3" json:"remoteKey,omitempty"`
Body *Body `protobuf:"bytes,4,opt,name=body,proto3" json:"body,omitempty"`
}
func (x *Message) Reset() {
*x = Message{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Message) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Message) ProtoMessage() {}
func (x *Message) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Message.ProtoReflect.Descriptor instead.
func (*Message) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{1}
}
func (x *Message) GetKey() string {
if x != nil {
return x.Key
}
return ""
}
func (x *Message) GetRemoteKey() string {
if x != nil {
return x.RemoteKey
}
return ""
}
func (x *Message) GetBody() *Body {
if x != nil {
return x.Body
}
return nil
}
// Actual body of the message that can contain credentials (type OFFER/ANSWER) or connection Candidate
// This part will be encrypted
type Body struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
// wgListenPort is an actual WireGuard listen port
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
NetBirdVersion string `protobuf:"bytes,4,opt,name=netBirdVersion,proto3" json:"netBirdVersion,omitempty"`
Mode *Mode `protobuf:"bytes,5,opt,name=mode,proto3" json:"mode,omitempty"`
// featuresSupported list of supported features by the client of this protocol
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
// relayServerAddress is url of the relay server
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
}
func (x *Body) Reset() {
*x = Body{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Body) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Body) ProtoMessage() {}
func (x *Body) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Body.ProtoReflect.Descriptor instead.
func (*Body) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{2}
}
func (x *Body) GetType() Body_Type {
if x != nil {
return x.Type
}
return Body_OFFER
}
func (x *Body) GetPayload() string {
if x != nil {
return x.Payload
}
return ""
}
func (x *Body) GetWgListenPort() uint32 {
if x != nil {
return x.WgListenPort
}
return 0
}
func (x *Body) GetNetBirdVersion() string {
if x != nil {
return x.NetBirdVersion
}
return ""
}
func (x *Body) GetMode() *Mode {
if x != nil {
return x.Mode
}
return nil
}
func (x *Body) GetFeaturesSupported() []uint32 {
if x != nil {
return x.FeaturesSupported
}
return nil
}
func (x *Body) GetRosenpassConfig() *RosenpassConfig {
if x != nil {
return x.RosenpassConfig
}
return nil
}
func (x *Body) GetRelayServerAddress() string {
if x != nil {
return x.RelayServerAddress
}
return ""
}
// Mode indicates a connection mode
type Mode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Direct *bool `protobuf:"varint,1,opt,name=direct,proto3,oneof" json:"direct,omitempty"`
}
func (x *Mode) Reset() {
*x = Mode{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Mode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Mode) ProtoMessage() {}
func (x *Mode) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Mode.ProtoReflect.Descriptor instead.
func (*Mode) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{3}
}
func (x *Mode) GetDirect() bool {
if x != nil && x.Direct != nil {
return *x.Direct
}
return false
}
type RosenpassConfig struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
RosenpassPubKey []byte `protobuf:"bytes,1,opt,name=rosenpassPubKey,proto3" json:"rosenpassPubKey,omitempty"`
// rosenpassServerAddr is an IP:port of the rosenpass service
RosenpassServerAddr string `protobuf:"bytes,2,opt,name=rosenpassServerAddr,proto3" json:"rosenpassServerAddr,omitempty"`
}
func (x *RosenpassConfig) Reset() {
*x = RosenpassConfig{}
if protoimpl.UnsafeEnabled {
mi := &file_signalexchange_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *RosenpassConfig) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RosenpassConfig) ProtoMessage() {}
func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
mi := &file_signalexchange_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RosenpassConfig.ProtoReflect.Descriptor instead.
func (*RosenpassConfig) Descriptor() ([]byte, []int) {
return file_signalexchange_proto_rawDescGZIP(), []int{4}
}
func (x *RosenpassConfig) GetRosenpassPubKey() []byte {
if x != nil {
return x.RosenpassPubKey
}
return nil
}
func (x *RosenpassConfig) GetRosenpassServerAddr() string {
if x != nil {
return x.RosenpassServerAddr
}
return ""
}
var File_signalexchange_proto protoreflect.FileDescriptor
var file_signalexchange_proto_rawDesc = []byte{
0x0a, 0x14, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74,
0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x56, 0x0a, 0x10, 0x45, 0x6e, 0x63, 0x72,
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03,
0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1c,
0x0a, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28,
0x09, 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04,
0x62, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
0x22, 0x63, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b,
0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1c, 0x0a,
0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xb3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x22, 0x0a, 0x0c, 0x77, 0x67, 0x4c, 0x69, 0x73,
0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x77,
0x67, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x6e,
0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20,
0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73,
0x69, 0x6f, 0x6e, 0x12, 0x28, 0x0a, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28,
0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
0x67, 0x65, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x2c, 0x0a,
0x11, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74,
0x65, 0x64, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x11, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72,
0x65, 0x73, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x12, 0x49, 0x0a, 0x0f, 0x72,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x07,
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x12, 0x0b,
0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x22, 0x2e, 0x0a, 0x04, 0x4d,
0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01,
0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52,
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28,
0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65,
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61,
0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65,
0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18,
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a,
0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c,
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43,
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e,
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20,
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_signalexchange_proto_rawDescOnce sync.Once
file_signalexchange_proto_rawDescData = file_signalexchange_proto_rawDesc
)
func file_signalexchange_proto_rawDescGZIP() []byte {
file_signalexchange_proto_rawDescOnce.Do(func() {
file_signalexchange_proto_rawDescData = protoimpl.X.CompressGZIP(file_signalexchange_proto_rawDescData)
})
return file_signalexchange_proto_rawDescData
}
var file_signalexchange_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_signalexchange_proto_goTypes = []interface{}{
(Body_Type)(0), // 0: signalexchange.Body.Type
(*EncryptedMessage)(nil), // 1: signalexchange.EncryptedMessage
(*Message)(nil), // 2: signalexchange.Message
(*Body)(nil), // 3: signalexchange.Body
(*Mode)(nil), // 4: signalexchange.Mode
(*RosenpassConfig)(nil), // 5: signalexchange.RosenpassConfig
}
var file_signalexchange_proto_depIdxs = []int32{
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
0, // 1: signalexchange.Body.type:type_name -> signalexchange.Body.Type
4, // 2: signalexchange.Body.mode:type_name -> signalexchange.Mode
5, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
1, // 4: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
1, // 5: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
1, // 6: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
1, // 7: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
6, // [6:8] is the sub-list for method output_type
4, // [4:6] is the sub-list for method input_type
4, // [4:4] is the sub-list for extension type_name
4, // [4:4] is the sub-list for extension extendee
0, // [0:4] is the sub-list for field type_name
}
func init() { file_signalexchange_proto_init() }
func file_signalexchange_proto_init() {
if File_signalexchange_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_signalexchange_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*EncryptedMessage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_signalexchange_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Message); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_signalexchange_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Body); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_signalexchange_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Mode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_signalexchange_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*RosenpassConfig); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_signalexchange_proto_rawDesc,
NumEnums: 1,
NumMessages: 5,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_signalexchange_proto_goTypes,
DependencyIndexes: file_signalexchange_proto_depIdxs,
EnumInfos: file_signalexchange_proto_enumTypes,
MessageInfos: file_signalexchange_proto_msgTypes,
}.Build()
File_signalexchange_proto = out.File
file_signalexchange_proto_rawDesc = nil
file_signalexchange_proto_goTypes = nil
file_signalexchange_proto_depIdxs = nil
}

View File

@@ -0,0 +1,78 @@
syntax = "proto3";
import "google/protobuf/descriptor.proto";
option go_package = "/proto";
package signalexchange;
service SignalExchange {
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
rpc Send(EncryptedMessage) returns (EncryptedMessage) {}
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
rpc ConnectStream(stream EncryptedMessage) returns (stream EncryptedMessage) {}
}
// Used for sending through signal.
// The body of this message is the Body message encrypted with the Wireguard private key and the remote Peer key
message EncryptedMessage {
// Wireguard public key
string key = 2;
// Wireguard public key of the remote peer to connect to
string remoteKey = 3;
// encrypted message Body
bytes body = 4;
}
// A decrypted representation of the EncryptedMessage. Used locally before/after encryption
message Message {
// WireGuard public key
string key = 2;
// WireGuard public key of the remote peer to connect to
string remoteKey = 3;
Body body = 4;
}
// Actual body of the message that can contain credentials (type OFFER/ANSWER) or connection Candidate
// This part will be encrypted
message Body {
// Message type
enum Type {
OFFER = 0;
ANSWER = 1;
CANDIDATE = 2;
MODE = 4;
GO_IDLE = 5;
}
Type type = 1;
string payload = 2;
// wgListenPort is an actual WireGuard listen port
uint32 wgListenPort = 3;
string netBirdVersion = 4;
Mode mode = 5;
// featuresSupported list of supported features by the client of this protocol
repeated uint32 featuresSupported = 6;
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
RosenpassConfig rosenpassConfig = 7;
// relayServerAddress is url of the relay server
string relayServerAddress = 8;
}
// Mode indicates a connection mode
message Mode {
optional bool direct = 1;
}
message RosenpassConfig {
bytes rosenpassPubKey = 1;
// rosenpassServerAddr is an IP:port of the rosenpass service
string rosenpassServerAddr = 2;
}

View File

@@ -0,0 +1,174 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
package proto
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// SignalExchangeClient is the client API for SignalExchange service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type SignalExchangeClient interface {
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
Send(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
ConnectStream(ctx context.Context, opts ...grpc.CallOption) (SignalExchange_ConnectStreamClient, error)
}
type signalExchangeClient struct {
cc grpc.ClientConnInterface
}
func NewSignalExchangeClient(cc grpc.ClientConnInterface) SignalExchangeClient {
return &signalExchangeClient{cc}
}
func (c *signalExchangeClient) Send(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
out := new(EncryptedMessage)
err := c.cc.Invoke(ctx, "/signalexchange.SignalExchange/Send", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *signalExchangeClient) ConnectStream(ctx context.Context, opts ...grpc.CallOption) (SignalExchange_ConnectStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &SignalExchange_ServiceDesc.Streams[0], "/signalexchange.SignalExchange/ConnectStream", opts...)
if err != nil {
return nil, err
}
x := &signalExchangeConnectStreamClient{stream}
return x, nil
}
type SignalExchange_ConnectStreamClient interface {
Send(*EncryptedMessage) error
Recv() (*EncryptedMessage, error)
grpc.ClientStream
}
type signalExchangeConnectStreamClient struct {
grpc.ClientStream
}
func (x *signalExchangeConnectStreamClient) Send(m *EncryptedMessage) error {
return x.ClientStream.SendMsg(m)
}
func (x *signalExchangeConnectStreamClient) Recv() (*EncryptedMessage, error) {
m := new(EncryptedMessage)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SignalExchangeServer is the server API for SignalExchange service.
// All implementations must embed UnimplementedSignalExchangeServer
// for forward compatibility
type SignalExchangeServer interface {
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
Send(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
ConnectStream(SignalExchange_ConnectStreamServer) error
mustEmbedUnimplementedSignalExchangeServer()
}
// UnimplementedSignalExchangeServer must be embedded to have forward compatible implementations.
type UnimplementedSignalExchangeServer struct {
}
func (UnimplementedSignalExchangeServer) Send(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
return nil, status.Errorf(codes.Unimplemented, "method Send not implemented")
}
func (UnimplementedSignalExchangeServer) ConnectStream(SignalExchange_ConnectStreamServer) error {
return status.Errorf(codes.Unimplemented, "method ConnectStream not implemented")
}
func (UnimplementedSignalExchangeServer) mustEmbedUnimplementedSignalExchangeServer() {}
// UnsafeSignalExchangeServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to SignalExchangeServer will
// result in compilation errors.
type UnsafeSignalExchangeServer interface {
mustEmbedUnimplementedSignalExchangeServer()
}
func RegisterSignalExchangeServer(s grpc.ServiceRegistrar, srv SignalExchangeServer) {
s.RegisterService(&SignalExchange_ServiceDesc, srv)
}
func _SignalExchange_Send_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(EncryptedMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SignalExchangeServer).Send(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/signalexchange.SignalExchange/Send",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SignalExchangeServer).Send(ctx, req.(*EncryptedMessage))
}
return interceptor(ctx, in, info, handler)
}
func _SignalExchange_ConnectStream_Handler(srv interface{}, stream grpc.ServerStream) error {
return srv.(SignalExchangeServer).ConnectStream(&signalExchangeConnectStreamServer{stream})
}
type SignalExchange_ConnectStreamServer interface {
Send(*EncryptedMessage) error
Recv() (*EncryptedMessage, error)
grpc.ServerStream
}
type signalExchangeConnectStreamServer struct {
grpc.ServerStream
}
func (x *signalExchangeConnectStreamServer) Send(m *EncryptedMessage) error {
return x.ServerStream.SendMsg(m)
}
func (x *signalExchangeConnectStreamServer) Recv() (*EncryptedMessage, error) {
m := new(EncryptedMessage)
if err := x.ServerStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
// SignalExchange_ServiceDesc is the grpc.ServiceDesc for SignalExchange service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var SignalExchange_ServiceDesc = grpc.ServiceDesc{
ServiceName: "signalexchange.SignalExchange",
HandlerType: (*SignalExchangeServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Send",
Handler: _SignalExchange_Send_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "ConnectStream",
Handler: _SignalExchange_ConnectStream_Handler,
ServerStreams: true,
ClientStreams: true,
},
},
Metadata: "signalexchange.proto",
}