mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[misc] Move shared components to shared directory (#4286)
Moved the following directories: ``` - management/client → shared/management/client - management/domain → shared/management/domain - management/proto → shared/management/proto - signal/client → shared/signal/client - signal/proto → shared/signal/proto - relay/client → shared/relay/client - relay/auth → shared/relay/auth ``` and adjusted import paths
This commit is contained in:
26
shared/management/client/client.go
Normal file
26
shared/management/client/client.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
io.Closer
|
||||
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
GetServerPublicKey() (*wgtypes.Key, error)
|
||||
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
||||
IsHealthy() bool
|
||||
SyncMeta(sysInfo *system.Info) error
|
||||
Logout() error
|
||||
}
|
||||
554
shared/management/client/client_test.go
Normal file
554
shared/management/client/client_test.go
Normal file
@@ -0,0 +1,554 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
mgmt "github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", util.LogConsole)
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
||||
t.Helper()
|
||||
level, _ := log.ParseLevel("debug")
|
||||
log.SetLevel(level)
|
||||
|
||||
config := &types.Config{}
|
||||
_, err := util.ReadJson("../server/testdata/management.json", config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(cleanUp)
|
||||
|
||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetSettings(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).
|
||||
Return(&types.Settings{}, nil).
|
||||
AnyTimes()
|
||||
settingsMockManager.
|
||||
EXPECT().
|
||||
GetExtraSettings(gomock.Any(), gomock.Any()).
|
||||
Return(&types.ExtraSettings{}, nil).
|
||||
AnyTimes()
|
||||
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
permissionsManagerMock.
|
||||
EXPECT().
|
||||
ValidateUserPermissions(
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
gomock.Any(),
|
||||
).
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil, mgmt.MockIntegratedValidator{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis
|
||||
}
|
||||
|
||||
func startMockManagement(t *testing.T) (*grpc.Server, net.Listener, *mock_server.ManagementServiceServerMock, wgtypes.Key) {
|
||||
t.Helper()
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s := grpc.NewServer()
|
||||
|
||||
serverKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mgmtMockServer := &mock_server.ManagementServiceServerMock{
|
||||
GetServerKeyFunc: func(context.Context, *mgmtProto.Empty) (*mgmtProto.ServerKeyResponse, error) {
|
||||
response := &mgmtProto.ServerKeyResponse{
|
||||
Key: serverKey.PublicKey().String(),
|
||||
}
|
||||
return response, nil
|
||||
},
|
||||
}
|
||||
|
||||
mgmtProto.RegisterManagementServiceServer(s, mgmtMockServer)
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis, mgmtMockServer, serverKey
|
||||
}
|
||||
|
||||
func closeManagementSilently(s *grpc.Server, listener net.Listener) {
|
||||
s.GracefulStop()
|
||||
err := listener.Close()
|
||||
if err != nil {
|
||||
log.Warnf("error while closing management listener %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_GetServerPublicKey(t *testing.T) {
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
s, listener := startManagement(t)
|
||||
defer closeManagementSilently(s, listener)
|
||||
|
||||
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error("couldn't retrieve management public key")
|
||||
}
|
||||
if key == nil {
|
||||
t.Error("got an empty management public key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
s, listener := startManagement(t)
|
||||
defer closeManagementSilently(s, listener)
|
||||
|
||||
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sysInfo := system.GetInfo(context.TODO())
|
||||
_, err = client.Login(*key, sysInfo, nil, nil)
|
||||
if err == nil {
|
||||
t.Error("expecting err on unregistered login, got nil")
|
||||
}
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.PermissionDenied {
|
||||
t.Errorf("expecting err code %d denied on unregistered login got %d", codes.PermissionDenied, s.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_LoginRegistered(t *testing.T) {
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
s, listener := startManagement(t)
|
||||
defer closeManagementSilently(s, listener)
|
||||
|
||||
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
key, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
info := system.GetInfo(context.TODO())
|
||||
resp, err := client.Register(*key, ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
t.Error("expecting non nil response, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_Sync(t *testing.T) {
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
s, listener := startManagement(t)
|
||||
defer closeManagementSilently(s, listener)
|
||||
|
||||
client, err := NewClient(ctx, listener.Addr().String(), testKey, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverKey, err := client.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
info := system.GetInfo(context.TODO())
|
||||
_, err = client.Register(*serverKey, ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// create and register second peer (we should receive on Sync request)
|
||||
remoteKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
remoteClient, err := NewClient(context.TODO(), listener.Addr().String(), remoteKey, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
info = system.GetInfo(context.TODO())
|
||||
_, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ch := make(chan *mgmtProto.SyncResponse, 1)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
err = client.Sync(ctx, info, func(msg *mgmtProto.SyncResponse) error {
|
||||
ch <- msg
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case resp := <-ch:
|
||||
if resp.GetPeerConfig() == nil {
|
||||
t.Error("expecting non nil PeerConfig got nil")
|
||||
}
|
||||
if resp.GetNetbirdConfig() == nil {
|
||||
t.Error("expecting non nil NetbirdConfig got nil")
|
||||
}
|
||||
if len(resp.GetRemotePeers()) != 1 {
|
||||
t.Errorf("expecting RemotePeers size %d got %d", 1, len(resp.GetRemotePeers()))
|
||||
return
|
||||
}
|
||||
if resp.GetRemotePeersIsEmpty() == true {
|
||||
t.Error("expecting RemotePeers property to be false, got true")
|
||||
}
|
||||
if resp.GetRemotePeers()[0].GetWgPubKey() != remoteKey.PublicKey().String() {
|
||||
t.Errorf("expecting RemotePeer public key %s got %s", remoteKey.PublicKey().String(), resp.GetRemotePeers()[0].GetWgPubKey())
|
||||
}
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Error("timeout waiting for test to finish")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SystemMetaDataFromClient(t *testing.T) {
|
||||
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
|
||||
defer s.GracefulStop()
|
||||
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverAddr := lis.Addr().String()
|
||||
ctx := context.Background()
|
||||
|
||||
testClient, err := NewClient(ctx, serverAddr, testKey, false)
|
||||
if err != nil {
|
||||
t.Fatalf("error while creating testClient: %v", err)
|
||||
}
|
||||
|
||||
key, err := testClient.GetServerPublicKey()
|
||||
if err != nil {
|
||||
t.Fatalf("error while getting server public key from testclient, %v", err)
|
||||
}
|
||||
|
||||
var actualMeta *mgmtProto.PeerSystemMeta
|
||||
var actualValidKey string
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
mgmtMockServer.LoginFunc = func(ctx context.Context, msg *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
peerKey, err := wgtypes.ParseKey(msg.GetWgPubKey())
|
||||
if err != nil {
|
||||
log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", msg.WgPubKey)
|
||||
return nil, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", msg.WgPubKey)
|
||||
}
|
||||
|
||||
loginReq := &mgmtProto.LoginRequest{}
|
||||
err = encryption.DecryptMessage(peerKey, serverKey, msg.Body, loginReq)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
actualMeta = loginReq.GetMeta()
|
||||
actualValidKey = loginReq.GetSetupKey()
|
||||
wg.Done()
|
||||
|
||||
loginResp := &mgmtProto.LoginResponse{}
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, serverKey, loginResp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: serverKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
Version: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
info := system.GetInfo(context.TODO())
|
||||
_, err = testClient.Register(*key, ValidKey, "", info, nil, nil)
|
||||
if err != nil {
|
||||
t.Errorf("error while trying to register client: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
protoNetAddr := make([]*mgmtProto.NetworkAddress, 0, len(info.NetworkAddresses))
|
||||
for _, addr := range info.NetworkAddresses {
|
||||
protoNetAddr = append(protoNetAddr, &mgmtProto.NetworkAddress{
|
||||
NetIP: addr.NetIP.String(),
|
||||
Mac: addr.Mac,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
expectedMeta := &mgmtProto.PeerSystemMeta{
|
||||
Hostname: info.Hostname,
|
||||
GoOS: info.GoOS,
|
||||
Kernel: info.Kernel,
|
||||
Platform: info.Platform,
|
||||
OS: info.OS,
|
||||
Core: info.OSVersion,
|
||||
OSVersion: info.OSVersion,
|
||||
NetbirdVersion: info.NetbirdVersion,
|
||||
KernelVersion: info.KernelVersion,
|
||||
|
||||
NetworkAddresses: protoNetAddr,
|
||||
SysSerialNumber: info.SystemSerialNumber,
|
||||
SysProductName: info.SystemProductName,
|
||||
SysManufacturer: info.SystemManufacturer,
|
||||
Environment: &mgmtProto.Environment{Cloud: info.Environment.Cloud, Platform: info.Environment.Platform},
|
||||
}
|
||||
|
||||
assert.Equal(t, ValidKey, actualValidKey)
|
||||
if !isEqual(expectedMeta, actualMeta) {
|
||||
t.Errorf("expected and actual meta are not equal")
|
||||
}
|
||||
}
|
||||
|
||||
func isEqual(a, b *mgmtProto.PeerSystemMeta) bool {
|
||||
if len(a.NetworkAddresses) != len(b.NetworkAddresses) {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, addr := range a.GetNetworkAddresses() {
|
||||
var found bool
|
||||
for _, oAddr := range b.GetNetworkAddresses() {
|
||||
if addr.GetMac() == oAddr.GetMac() && addr.GetNetIP() == oAddr.GetNetIP() {
|
||||
found = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("------")
|
||||
|
||||
return a.GetHostname() == b.GetHostname() &&
|
||||
a.GetGoOS() == b.GetGoOS() &&
|
||||
a.GetKernel() == b.GetKernel() &&
|
||||
a.GetKernelVersion() == b.GetKernelVersion() &&
|
||||
a.GetCore() == b.GetCore() &&
|
||||
a.GetPlatform() == b.GetPlatform() &&
|
||||
a.GetOS() == b.GetOS() &&
|
||||
a.GetOSVersion() == b.GetOSVersion() &&
|
||||
a.GetNetbirdVersion() == b.GetNetbirdVersion() &&
|
||||
a.GetUiVersion() == b.GetUiVersion() &&
|
||||
a.GetSysSerialNumber() == b.GetSysSerialNumber() &&
|
||||
a.GetSysProductName() == b.GetSysProductName() &&
|
||||
a.GetSysManufacturer() == b.GetSysManufacturer() &&
|
||||
a.GetEnvironment().Cloud == b.GetEnvironment().Cloud &&
|
||||
a.GetEnvironment().Platform == b.GetEnvironment().Platform
|
||||
}
|
||||
|
||||
func Test_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
|
||||
defer s.GracefulStop()
|
||||
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverAddr := lis.Addr().String()
|
||||
ctx := context.Background()
|
||||
|
||||
client, err := NewClient(ctx, serverAddr, testKey, false)
|
||||
if err != nil {
|
||||
t.Fatalf("error while creating testClient: %v", err)
|
||||
}
|
||||
|
||||
expectedFlowInfo := &mgmtProto.DeviceAuthorizationFlow{
|
||||
Provider: 0,
|
||||
ProviderConfig: &mgmtProto.ProviderConfig{ClientID: "client"},
|
||||
}
|
||||
|
||||
mgmtMockServer.GetDeviceAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: serverKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
Version: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
flowInfo, err := client.GetDeviceAuthorizationFlow(serverKey)
|
||||
if err != nil {
|
||||
t.Error("error while retrieving device auth flow information")
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedFlowInfo.Provider, flowInfo.Provider, "provider should match")
|
||||
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
|
||||
}
|
||||
|
||||
func Test_GetPKCEAuthorizationFlow(t *testing.T) {
|
||||
s, lis, mgmtMockServer, serverKey := startMockManagement(t)
|
||||
defer s.GracefulStop()
|
||||
|
||||
testKey, err := wgtypes.GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
serverAddr := lis.Addr().String()
|
||||
ctx := context.Background()
|
||||
|
||||
client, err := NewClient(ctx, serverAddr, testKey, false)
|
||||
if err != nil {
|
||||
t.Fatalf("error while creating testClient: %v", err)
|
||||
}
|
||||
|
||||
expectedFlowInfo := &mgmtProto.PKCEAuthorizationFlow{
|
||||
ProviderConfig: &mgmtProto.ProviderConfig{
|
||||
ClientID: "client",
|
||||
ClientSecret: "secret",
|
||||
},
|
||||
}
|
||||
|
||||
mgmtMockServer.GetPKCEAuthorizationFlowFunc = func(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
|
||||
encryptedResp, err := encryption.EncryptMessage(serverKey, client.key, expectedFlowInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mgmtProto.EncryptedMessage{
|
||||
WgPubKey: serverKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
Version: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
flowInfo, err := client.GetPKCEAuthorizationFlow(serverKey)
|
||||
if err != nil {
|
||||
t.Error("error while retrieving pkce auth flow information")
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientID, flowInfo.ProviderConfig.ClientID, "provider configured client ID should match")
|
||||
assert.Equal(t, expectedFlowInfo.ProviderConfig.ClientSecret, flowInfo.ProviderConfig.ClientSecret, "provider configured client secret should match")
|
||||
}
|
||||
19
shared/management/client/common/types.go
Normal file
19
shared/management/client/common/types.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package common
|
||||
|
||||
// LoginFlag introduces additional login flags to the PKCE authorization request
|
||||
type LoginFlag uint8
|
||||
|
||||
const (
|
||||
// LoginFlagPrompt adds prompt=login to the authorization request
|
||||
LoginFlagPrompt LoginFlag = iota
|
||||
// LoginFlagMaxAge0 adds max_age=0 to the authorization request
|
||||
LoginFlagMaxAge0
|
||||
)
|
||||
|
||||
func (l LoginFlag) IsPromptLogin() bool {
|
||||
return l == LoginFlagPrompt
|
||||
}
|
||||
|
||||
func (l LoginFlag) IsMaxAge0Login() bool {
|
||||
return l == LoginFlagMaxAge0
|
||||
}
|
||||
3
shared/management/client/go.sum
Normal file
3
shared/management/client/go.sum
Normal file
@@ -0,0 +1,3 @@
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
|
||||
584
shared/management/client/grpc.go
Normal file
584
shared/management/client/grpc.go
Normal file
@@ -0,0 +1,584 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
||||
)
|
||||
|
||||
const ConnectTimeout = 10 * time.Second
|
||||
|
||||
const (
|
||||
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
|
||||
errMsgNoMgmtConnection = "no connection to management"
|
||||
)
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorders
|
||||
type ConnStateNotifier interface {
|
||||
MarkManagementDisconnected(error)
|
||||
MarkManagementConnected()
|
||||
}
|
||||
|
||||
type GrpcClient struct {
|
||||
key wgtypes.Key
|
||||
realClient proto.ManagementServiceClient
|
||||
ctx context.Context
|
||||
conn *grpc.ClientConn
|
||||
connStateCallback ConnStateNotifier
|
||||
connStateCallbackLock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewClient creates a new client to Management service
|
||||
func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||
var conn *grpc.ClientConn
|
||||
|
||||
operation := func() error {
|
||||
var err error
|
||||
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
|
||||
if err != nil {
|
||||
log.Printf("createConnection error: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
|
||||
if err != nil {
|
||||
log.Errorf("failed creating connection to Management Service: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
realClient := proto.NewManagementServiceClient(conn)
|
||||
|
||||
return &GrpcClient{
|
||||
key: ourPrivateKey,
|
||||
realClient: realClient,
|
||||
ctx: ctx,
|
||||
conn: conn,
|
||||
connStateCallbackLock: sync.RWMutex{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes connection to the Management Service
|
||||
func (c *GrpcClient) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// SetConnStateListener set the ConnStateNotifier
|
||||
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
|
||||
c.connStateCallbackLock.Lock()
|
||||
defer c.connStateCallbackLock.Unlock()
|
||||
c.connStateCallback = notifier
|
||||
}
|
||||
|
||||
// defaultBackoff is a basic backoff mechanism for general issues
|
||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 1,
|
||||
Multiplier: 1.7,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// ready indicates whether the client is okay and ready to be used
|
||||
// for now it just checks whether gRPC connection to the service is ready
|
||||
func (c *GrpcClient) ready() bool {
|
||||
return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle
|
||||
}
|
||||
|
||||
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
|
||||
// Blocking request. The result will be sent via msgHandler callback function
|
||||
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
operation := func() error {
|
||||
log.Debugf("management connection state %v", c.conn.GetState())
|
||||
connState := c.conn.GetState()
|
||||
|
||||
if connState == connectivity.Shutdown {
|
||||
return backoff.Permanent(fmt.Errorf("connection to management has been shut down"))
|
||||
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
||||
c.conn.WaitForStateChange(ctx, connState)
|
||||
return fmt.Errorf("connection to management is not ready and in %s state", connState)
|
||||
}
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf(errMsgMgmtPublicKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler)
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, defaultBackoff(ctx))
|
||||
if err != nil {
|
||||
log.Warnf("exiting the Management service connection retry loop due to the unrecoverable error: %s", err)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
|
||||
msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
ctx, cancelStream := context.WithCancel(ctx)
|
||||
defer cancelStream()
|
||||
|
||||
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open Management Service stream: %s", err)
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("connected to the Management Service stream")
|
||||
c.notifyConnected()
|
||||
|
||||
// blocking until error
|
||||
err = c.receiveEvents(stream, serverPubKey, msgHandler)
|
||||
if err != nil {
|
||||
c.notifyDisconnected(err)
|
||||
s, _ := gstatus.FromError(err)
|
||||
switch s.Code() {
|
||||
case codes.PermissionDenied:
|
||||
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
|
||||
case codes.Canceled:
|
||||
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
default:
|
||||
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNetworkMap return with the network map
|
||||
func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error) {
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf("failed getting Management Service public key: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancelStream := context.WithCancel(c.ctx)
|
||||
defer cancelStream()
|
||||
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
|
||||
if err != nil {
|
||||
log.Debugf("failed to open Management Service stream: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = stream.CloseSend()
|
||||
}()
|
||||
|
||||
update, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
log.Debugf("Management stream has been closed by server: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
log.Debugf("disconnected from Management Service sync stream: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decryptedResp := &proto.SyncResponse{}
|
||||
err = encryption.DecryptMessage(*serverPubKey, c.key, update.Body, decryptedResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed decrypting update message from Management Service: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if decryptedResp.GetNetworkMap() == nil {
|
||||
return nil, fmt.Errorf("invalid msg, required network map")
|
||||
}
|
||||
|
||||
return decryptedResp.GetNetworkMap(), nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
|
||||
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
|
||||
|
||||
myPrivateKey := c.key
|
||||
myPublicKey := myPrivateKey.PublicKey()
|
||||
|
||||
encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed encrypting message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
|
||||
sync, err := c.realClient.Sync(ctx, syncReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sync, nil
|
||||
}
|
||||
|
||||
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
for {
|
||||
update, err := stream.Recv()
|
||||
if err == io.EOF {
|
||||
log.Debugf("Management stream has been closed by server: %s", err)
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
log.Debugf("disconnected from Management Service sync stream: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("got an update message from Management Service")
|
||||
decryptedResp := &proto.SyncResponse{}
|
||||
err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed decrypting update message from Management Service: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := msgHandler(decryptedResp); err != nil {
|
||||
log.Errorf("failed handling an update message received from Management Service: %v", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetServerPublicKey returns server's WireGuard public key (used later for encrypting messages sent to the server)
|
||||
func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if !c.ready() {
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{})
|
||||
if err != nil {
|
||||
log.Errorf("failed while getting Management Service public key: %v", err)
|
||||
return nil, fmt.Errorf("failed while getting Management Service public key")
|
||||
}
|
||||
|
||||
serverKey, err := wgtypes.ParseKey(resp.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &serverKey, nil
|
||||
}
|
||||
|
||||
// IsHealthy probes the gRPC connection and returns false on errors
|
||||
func (c *GrpcClient) IsHealthy() bool {
|
||||
switch c.conn.GetState() {
|
||||
case connectivity.TransientFailure:
|
||||
return false
|
||||
case connectivity.Connecting:
|
||||
return true
|
||||
case connectivity.Shutdown:
|
||||
return true
|
||||
case connectivity.Idle:
|
||||
case connectivity.Ready:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
|
||||
if err != nil {
|
||||
c.notifyDisconnected(err)
|
||||
log.Warnf("health check returned: %s", err)
|
||||
return false
|
||||
}
|
||||
c.notifyConnected()
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) {
|
||||
if !c.ready() {
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
loginReq, err := encryption.EncryptMessage(serverKey, c.key, req)
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp *proto.EncryptedMessage
|
||||
operation := func() error {
|
||||
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
resp, err = c.realClient.Login(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: loginReq,
|
||||
})
|
||||
if err != nil {
|
||||
// retry only on context canceled
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
return err
|
||||
}
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err = backoff.Retry(operation, nbgrpc.Backoff(c.ctx))
|
||||
if err != nil {
|
||||
log.Errorf("failed to login to Management Service: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginResp := &proto.LoginResponse{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt login response: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
|
||||
// Takes care of encrypting and decrypting messages.
|
||||
// This method will also collect system info and send it with the request (e.g. hostname, os, etc)
|
||||
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
keys := &proto.PeerKeys{
|
||||
SshPubKey: pubSSHKey,
|
||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||
}
|
||||
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
}
|
||||
|
||||
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
keys := &proto.PeerKeys{
|
||||
SshPubKey: pubSSHKey,
|
||||
WgPubKey: []byte(c.key.PublicKey().String()),
|
||||
}
|
||||
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlow returns a device authorization flow information.
|
||||
// It also takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management in order to get device authorization flow")
|
||||
}
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
||||
defer cancel()
|
||||
|
||||
message := &proto.DeviceAuthorizationFlowRequest{}
|
||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.realClient.GetDeviceAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: encryptedMSG},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.DeviceAuthorizationFlow{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
||||
if err != nil {
|
||||
errWithMSG := fmt.Errorf("failed to decrypt device authorization flow message: %s", err)
|
||||
log.Error(errWithMSG)
|
||||
return nil, errWithMSG
|
||||
}
|
||||
|
||||
return flowInfoResp, nil
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlow returns a pkce authorization flow information.
|
||||
// It also takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
||||
if !c.ready() {
|
||||
return nil, fmt.Errorf("no connection to management in order to get pkce authorization flow")
|
||||
}
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*2)
|
||||
defer cancel()
|
||||
|
||||
message := &proto.PKCEAuthorizationFlowRequest{}
|
||||
encryptedMSG, err := encryption.EncryptMessage(serverKey, c.key, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.realClient.GetPKCEAuthorizationFlow(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: encryptedMSG,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flowInfoResp := &proto.PKCEAuthorizationFlow{}
|
||||
err = encryption.DecryptMessage(serverKey, c.key, resp.Body, flowInfoResp)
|
||||
if err != nil {
|
||||
errWithMSG := fmt.Errorf("failed to decrypt pkce authorization flow message: %s", err)
|
||||
log.Error(errWithMSG)
|
||||
return nil, errWithMSG
|
||||
}
|
||||
|
||||
return flowInfoResp, nil
|
||||
}
|
||||
|
||||
// SyncMeta sends updated system metadata to the Management Service.
|
||||
// It should be used if there is changes on peer posture check after initial sync.
|
||||
func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
|
||||
if !c.ready() {
|
||||
return errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
serverPubKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
log.Debugf(errMsgMgmtPublicKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
syncMetaReq, err := encryption.EncryptMessage(*serverPubKey, c.key, &proto.SyncMetaRequest{Meta: infoToMetaData(sysInfo)})
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt message: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, ConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err = c.realClient.SyncMeta(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: syncMetaReq,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *GrpcClient) notifyDisconnected(err error) {
|
||||
c.connStateCallbackLock.RLock()
|
||||
defer c.connStateCallbackLock.RUnlock()
|
||||
|
||||
if c.connStateCallback == nil {
|
||||
return
|
||||
}
|
||||
c.connStateCallback.MarkManagementDisconnected(err)
|
||||
}
|
||||
|
||||
func (c *GrpcClient) notifyConnected() {
|
||||
c.connStateCallbackLock.RLock()
|
||||
defer c.connStateCallbackLock.RUnlock()
|
||||
|
||||
if c.connStateCallback == nil {
|
||||
return
|
||||
}
|
||||
c.connStateCallback.MarkManagementConnected()
|
||||
}
|
||||
|
||||
func (c *GrpcClient) Logout() error {
|
||||
serverKey, err := c.GetServerPublicKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get server public key: %w", err)
|
||||
}
|
||||
|
||||
mgmCtx, cancel := context.WithTimeout(c.ctx, time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
message := &proto.Empty{}
|
||||
encryptedMSG, err := encryption.EncryptMessage(*serverKey, c.key, message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt logout message: %w", err)
|
||||
}
|
||||
|
||||
_, err = c.realClient.Logout(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: encryptedMSG,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
addresses := make([]*proto.NetworkAddress, 0, len(info.NetworkAddresses))
|
||||
for _, addr := range info.NetworkAddresses {
|
||||
addresses = append(addresses, &proto.NetworkAddress{
|
||||
NetIP: addr.NetIP.String(),
|
||||
Mac: addr.Mac,
|
||||
})
|
||||
}
|
||||
|
||||
files := make([]*proto.File, 0, len(info.Files))
|
||||
for _, file := range info.Files {
|
||||
files = append(files, &proto.File{
|
||||
Path: file.Path,
|
||||
Exist: file.Exist,
|
||||
ProcessIsRunning: file.ProcessIsRunning,
|
||||
})
|
||||
}
|
||||
|
||||
return &proto.PeerSystemMeta{
|
||||
Hostname: info.Hostname,
|
||||
GoOS: info.GoOS,
|
||||
OS: info.OS,
|
||||
Core: info.OSVersion,
|
||||
OSVersion: info.OSVersion,
|
||||
Platform: info.Platform,
|
||||
Kernel: info.Kernel,
|
||||
NetbirdVersion: info.NetbirdVersion,
|
||||
UiVersion: info.UIVersion,
|
||||
KernelVersion: info.KernelVersion,
|
||||
NetworkAddresses: addresses,
|
||||
SysSerialNumber: info.SystemSerialNumber,
|
||||
SysManufacturer: info.SystemManufacturer,
|
||||
SysProductName: info.SystemProductName,
|
||||
Environment: &proto.Environment{
|
||||
Cloud: info.Environment.Cloud,
|
||||
Platform: info.Environment.Platform,
|
||||
},
|
||||
Files: files,
|
||||
|
||||
Flags: &proto.Flags{
|
||||
RosenpassEnabled: info.RosenpassEnabled,
|
||||
RosenpassPermissive: info.RosenpassPermissive,
|
||||
ServerSSHAllowed: info.ServerSSHAllowed,
|
||||
|
||||
DisableClientRoutes: info.DisableClientRoutes,
|
||||
DisableServerRoutes: info.DisableServerRoutes,
|
||||
DisableDNS: info.DisableDNS,
|
||||
DisableFirewall: info.DisableFirewall,
|
||||
BlockLANAccess: info.BlockLANAccess,
|
||||
BlockInbound: info.BlockInbound,
|
||||
|
||||
LazyConnectionEnabled: info.LazyConnectionEnabled,
|
||||
},
|
||||
}
|
||||
}
|
||||
95
shared/management/client/mock.go
Normal file
95
shared/management/client/mock.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type MockClient struct {
|
||||
CloseFunc func() error
|
||||
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
GetServerPublicKeyFunc func() (*wgtypes.Key, error)
|
||||
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
|
||||
SyncMetaFunc func(sysInfo *system.Info) error
|
||||
LogoutFunc func() error
|
||||
}
|
||||
|
||||
func (m *MockClient) IsHealthy() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *MockClient) Close() error {
|
||||
if m.CloseFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.CloseFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
|
||||
if m.SyncFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.SyncFunc(ctx, sysInfo, msgHandler)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
|
||||
if m.GetServerPublicKeyFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetServerPublicKeyFunc()
|
||||
}
|
||||
|
||||
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
if m.RegisterFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey, dnsLabels)
|
||||
}
|
||||
|
||||
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error) {
|
||||
if m.LoginFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.LoginFunc(serverKey, info, sshKey, dnsLabels)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {
|
||||
if m.GetDeviceAuthorizationFlowFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetDeviceAuthorizationFlowFunc(serverKey)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetPKCEAuthorizationFlow(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error) {
|
||||
if m.GetPKCEAuthorizationFlowFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.GetPKCEAuthorizationFlow(serverKey)
|
||||
}
|
||||
|
||||
// GetNetworkMap mock implementation of GetNetworkMap from mgm.Client interface
|
||||
func (m *MockClient) GetNetworkMap(_ *system.Info) (*proto.NetworkMap, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockClient) SyncMeta(sysInfo *system.Info) error {
|
||||
if m.SyncMetaFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.SyncMetaFunc(sysInfo)
|
||||
}
|
||||
|
||||
func (m *MockClient) Logout() error {
|
||||
if m.LogoutFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return m.LogoutFunc()
|
||||
}
|
||||
60
shared/management/client/rest/accounts.go
Normal file
60
shared/management/client/rest/accounts.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// AccountsAPI APIs for accounts, do not use directly
|
||||
type AccountsAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all accounts, only returns one account always
|
||||
// See more: https://docs.netbird.io/api/resources/accounts#list-all-accounts
|
||||
func (a *AccountsAPI) List(ctx context.Context) ([]api.Account, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/accounts", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Account](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Update update account settings
|
||||
// See more: https://docs.netbird.io/api/resources/accounts#update-an-account
|
||||
func (a *AccountsAPI) Update(ctx context.Context, accountID string, request api.PutApiAccountsAccountIdJSONRequestBody) (*api.Account, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/accounts/"+accountID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Account](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete account
|
||||
// See more: https://docs.netbird.io/api/resources/accounts#delete-an-account
|
||||
func (a *AccountsAPI) Delete(ctx context.Context, accountID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/accounts/"+accountID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
183
shared/management/client/rest/accounts_test.go
Normal file
183
shared/management/client/rest/accounts_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testAccount = api.Account{
|
||||
Id: "Test",
|
||||
Settings: api.AccountSettings{
|
||||
Extra: &api.AccountExtraSettings{
|
||||
PeerApprovalEnabled: false,
|
||||
},
|
||||
GroupsPropagationEnabled: ptr(true),
|
||||
JwtGroupsEnabled: ptr(false),
|
||||
PeerInactivityExpiration: 7,
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerLoginExpiration: 24,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
RegularUsersViewBlocked: false,
|
||||
RoutingPeerDnsResolutionEnabled: ptr(false),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func TestAccounts_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Account{testAccount})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Accounts.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testAccount, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/accounts", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Accounts.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_List_ConnErr(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
ret, err := c.Accounts.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "404")
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiAccountsAccountIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, *req.Settings.RoutingPeerDnsResolutionEnabled)
|
||||
retBytes, _ := json.Marshal(testAccount)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Accounts.Update(context.Background(), "Test", api.PutApiAccountsAccountIdJSONRequestBody{
|
||||
Settings: api.AccountSettings{
|
||||
RoutingPeerDnsResolutionEnabled: ptr(true),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testAccount, *ret)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestAccounts_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Accounts.Update(context.Background(), "Test", api.PutApiAccountsAccountIdJSONRequestBody{
|
||||
Settings: api.AccountSettings{
|
||||
RoutingPeerDnsResolutionEnabled: ptr(true),
|
||||
},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Accounts.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/accounts/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Accounts.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_Integration_List(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
accounts, err := c.Accounts.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, accounts, 1)
|
||||
assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", accounts[0].Id)
|
||||
assert.Equal(t, false, accounts[0].Settings.Extra.PeerApprovalEnabled)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccounts_Integration_Update(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
accounts, err := c.Accounts.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, accounts, 1)
|
||||
accounts[0].Settings.JwtAllowGroups = ptr([]string{"test"})
|
||||
account, err := c.Accounts.Update(context.Background(), accounts[0].Id, api.AccountRequest{
|
||||
Settings: accounts[0].Settings,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accounts[0].Id, account.Id)
|
||||
assert.Equal(t, []string{"test"}, *account.Settings.JwtAllowGroups)
|
||||
})
|
||||
}
|
||||
|
||||
// Account deletion on MySQL and PostgreSQL databases causes unknown errors
|
||||
// func TestAccounts_Integration_Delete(t *testing.T) {
|
||||
// withBlackBoxServer(t, func(c *rest.Client) {
|
||||
// accounts, err := c.Accounts.List(context.Background())
|
||||
// require.NoError(t, err)
|
||||
// assert.Len(t, accounts, 1)
|
||||
// err = c.Accounts.Delete(context.Background(), accounts[0].Id)
|
||||
// require.NoError(t, err)
|
||||
// _, err = c.Accounts.List(context.Background())
|
||||
// assert.Error(t, err)
|
||||
// })
|
||||
// }
|
||||
172
shared/management/client/rest/client.go
Normal file
172
shared/management/client/rest/client.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
// Client Management service HTTP REST API Client
|
||||
type Client struct {
|
||||
managementURL string
|
||||
authHeader string
|
||||
httpClient HttpClient
|
||||
|
||||
// Accounts NetBird account APIs
|
||||
// see more: https://docs.netbird.io/api/resources/accounts
|
||||
Accounts *AccountsAPI
|
||||
|
||||
// Users NetBird users APIs
|
||||
// see more: https://docs.netbird.io/api/resources/users
|
||||
Users *UsersAPI
|
||||
|
||||
// Tokens NetBird tokens APIs
|
||||
// see more: https://docs.netbird.io/api/resources/tokens
|
||||
Tokens *TokensAPI
|
||||
|
||||
// Peers NetBird peers APIs
|
||||
// see more: https://docs.netbird.io/api/resources/peers
|
||||
Peers *PeersAPI
|
||||
|
||||
// SetupKeys NetBird setup keys APIs
|
||||
// see more: https://docs.netbird.io/api/resources/setup-keys
|
||||
SetupKeys *SetupKeysAPI
|
||||
|
||||
// Groups NetBird groups APIs
|
||||
// see more: https://docs.netbird.io/api/resources/groups
|
||||
Groups *GroupsAPI
|
||||
|
||||
// Policies NetBird policies APIs
|
||||
// see more: https://docs.netbird.io/api/resources/policies
|
||||
Policies *PoliciesAPI
|
||||
|
||||
// PostureChecks NetBird posture checks APIs
|
||||
// see more: https://docs.netbird.io/api/resources/posture-checks
|
||||
PostureChecks *PostureChecksAPI
|
||||
|
||||
// Networks NetBird networks APIs
|
||||
// see more: https://docs.netbird.io/api/resources/networks
|
||||
Networks *NetworksAPI
|
||||
|
||||
// Routes NetBird routes APIs
|
||||
// see more: https://docs.netbird.io/api/resources/routes
|
||||
Routes *RoutesAPI
|
||||
|
||||
// DNS NetBird DNS APIs
|
||||
// see more: https://docs.netbird.io/api/resources/routes
|
||||
DNS *DNSAPI
|
||||
|
||||
// GeoLocation NetBird Geo Location APIs
|
||||
// see more: https://docs.netbird.io/api/resources/geo-locations
|
||||
GeoLocation *GeoLocationAPI
|
||||
|
||||
// Events NetBird Events APIs
|
||||
// see more: https://docs.netbird.io/api/resources/events
|
||||
Events *EventsAPI
|
||||
}
|
||||
|
||||
// New initialize new Client instance using PAT token
|
||||
func New(managementURL, token string) *Client {
|
||||
return NewWithOptions(
|
||||
WithManagementURL(managementURL),
|
||||
WithPAT(token),
|
||||
)
|
||||
}
|
||||
|
||||
// NewWithBearerToken initialize new Client instance using Bearer token type
|
||||
func NewWithBearerToken(managementURL, token string) *Client {
|
||||
return NewWithOptions(
|
||||
WithManagementURL(managementURL),
|
||||
WithBearerToken(token),
|
||||
)
|
||||
}
|
||||
|
||||
// NewWithOptions initialize new Client instance with options
|
||||
func NewWithOptions(opts ...option) *Client {
|
||||
client := &Client{
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
|
||||
for _, option := range opts {
|
||||
option(client)
|
||||
}
|
||||
|
||||
client.initialize()
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *Client) initialize() {
|
||||
c.Accounts = &AccountsAPI{c}
|
||||
c.Users = &UsersAPI{c}
|
||||
c.Tokens = &TokensAPI{c}
|
||||
c.Peers = &PeersAPI{c}
|
||||
c.SetupKeys = &SetupKeysAPI{c}
|
||||
c.Groups = &GroupsAPI{c}
|
||||
c.Policies = &PoliciesAPI{c}
|
||||
c.PostureChecks = &PostureChecksAPI{c}
|
||||
c.Networks = &NetworksAPI{c}
|
||||
c.Routes = &RoutesAPI{c}
|
||||
c.DNS = &DNSAPI{c}
|
||||
c.GeoLocation = &GeoLocationAPI{c}
|
||||
c.Events = &EventsAPI{c}
|
||||
}
|
||||
|
||||
// NewRequest creates and executes new management API request
|
||||
func (c *Client) NewRequest(ctx context.Context, method, path string, body io.Reader, query map[string]string) (*http.Response, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, method, c.managementURL+path, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", c.authHeader)
|
||||
req.Header.Add("Accept", "application/json")
|
||||
if body != nil {
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
if len(query) != 0 {
|
||||
q := req.URL.Query()
|
||||
for k, v := range query {
|
||||
q.Add(k, v)
|
||||
}
|
||||
req.URL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode > 299 {
|
||||
parsedErr, pErr := parseResponse[util.ErrorResponse](resp)
|
||||
if pErr != nil {
|
||||
|
||||
return nil, pErr
|
||||
}
|
||||
return nil, errors.New(parsedErr.Message)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseResponse[T any](resp *http.Response) (T, error) {
|
||||
var ret T
|
||||
if resp.Body == nil {
|
||||
return ret, fmt.Errorf("Body missing, HTTP Error code %d", resp.StatusCode)
|
||||
}
|
||||
bs, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ret, err
|
||||
}
|
||||
err = json.Unmarshal(bs, &ret)
|
||||
if err != nil {
|
||||
return ret, fmt.Errorf("Error code %d, error unmarshalling body: %w", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
34
shared/management/client/rest/client_test.go
Normal file
34
shared/management/client/rest/client_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
func withMockClient(callback func(*rest.Client, *http.ServeMux)) {
|
||||
mux := &http.ServeMux{}
|
||||
server := httptest.NewServer(mux)
|
||||
defer server.Close()
|
||||
c := rest.New(server.URL, "ABC")
|
||||
callback(c, mux)
|
||||
}
|
||||
|
||||
func ptr[T any, PT *T](x T) PT {
|
||||
return &x
|
||||
}
|
||||
|
||||
func withBlackBoxServer(t *testing.T, callback func(*rest.Client)) {
|
||||
t.Helper()
|
||||
handler, _, _ := testing_tools.BuildApiBlackBoxWithDBState(t, "../../server/testdata/store.sql", nil, false)
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
c := rest.New(server.URL, "nbp_apTmlmUXHSC4PKmHwtIZNaGr8eqcVI2gMURp")
|
||||
callback(c)
|
||||
}
|
||||
124
shared/management/client/rest/dns.go
Normal file
124
shared/management/client/rest/dns.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// DNSAPI APIs for DNS Management, do not use directly
|
||||
type DNSAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// ListNameserverGroups list all nameserver groups
|
||||
// See more: https://docs.netbird.io/api/resources/dns#list-all-nameserver-groups
|
||||
func (a *DNSAPI) ListNameserverGroups(ctx context.Context) ([]api.NameserverGroup, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.NameserverGroup](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// GetNameserverGroup get nameserver group info
|
||||
// See more: https://docs.netbird.io/api/resources/dns#retrieve-a-nameserver-group
|
||||
func (a *DNSAPI) GetNameserverGroup(ctx context.Context, nameserverGroupID string) (*api.NameserverGroup, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/nameservers/"+nameserverGroupID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// CreateNameserverGroup create new nameserver group
|
||||
// See more: https://docs.netbird.io/api/resources/dns#create-a-nameserver-group
|
||||
func (a *DNSAPI) CreateNameserverGroup(ctx context.Context, request api.PostApiDnsNameserversJSONRequestBody) (*api.NameserverGroup, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/nameservers", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateNameserverGroup update nameserver group info
|
||||
// See more: https://docs.netbird.io/api/resources/dns#update-a-nameserver-group
|
||||
func (a *DNSAPI) UpdateNameserverGroup(ctx context.Context, nameserverGroupID string, request api.PutApiDnsNameserversNsgroupIdJSONRequestBody) (*api.NameserverGroup, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/nameservers/"+nameserverGroupID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NameserverGroup](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// DeleteNameserverGroup delete nameserver group
|
||||
// See more: https://docs.netbird.io/api/resources/dns#delete-a-nameserver-group
|
||||
func (a *DNSAPI) DeleteNameserverGroup(ctx context.Context, nameserverGroupID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/nameservers/"+nameserverGroupID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSettings get DNS settings
|
||||
// See more: https://docs.netbird.io/api/resources/dns#retrieve-dns-settings
|
||||
func (a *DNSAPI) GetSettings(ctx context.Context) (*api.DNSSettings, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/settings", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.DNSSettings](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// UpdateSettings update DNS settings
|
||||
// See more: https://docs.netbird.io/api/resources/dns#update-dns-settings
|
||||
func (a *DNSAPI) UpdateSettings(ctx context.Context, request api.PutApiDnsSettingsJSONRequestBody) (*api.DNSSettings, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/settings", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.DNSSettings](resp)
|
||||
return &ret, err
|
||||
}
|
||||
301
shared/management/client/rest/dns_test.go
Normal file
301
shared/management/client/rest/dns_test.go
Normal file
@@ -0,0 +1,301 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testNameserverGroup = api.NameserverGroup{
|
||||
Id: "Test",
|
||||
Name: "wow",
|
||||
}
|
||||
|
||||
testSettings = api.DNSSettings{
|
||||
DisabledManagementGroups: []string{"gone"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestDNSNameserverGroup_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.NameserverGroup{testNameserverGroup})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.ListNameserverGroups(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testNameserverGroup, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.ListNameserverGroups(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testNameserverGroup)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.GetNameserverGroup(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNameserverGroup, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.GetNameserverGroup(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiDnsNameserversJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testNameserverGroup)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.CreateNameserverGroup(context.Background(), api.PostApiDnsNameserversJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNameserverGroup, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.CreateNameserverGroup(context.Background(), api.PostApiDnsNameserversJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testNameserverGroup)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.UpdateNameserverGroup(context.Background(), "Test", api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNameserverGroup, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.UpdateNameserverGroup(context.Background(), "Test", api.PutApiDnsNameserversNsgroupIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.DNS.DeleteNameserverGroup(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSNameserverGroup_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/nameservers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.DNS.DeleteNameserverGroup(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSSettings_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testSettings)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.GetSettings(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSettings, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSSettings_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.GetSettings(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSSettings_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiDnsSettingsJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"test"}, req.DisabledManagementGroups)
|
||||
retBytes, _ := json.Marshal(testSettings)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.UpdateSettings(context.Background(), api.PutApiDnsSettingsJSONRequestBody{
|
||||
DisabledManagementGroups: []string{"test"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSettings, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNSSettings_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/dns/settings", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.DNS.UpdateSettings(context.Background(), api.PutApiDnsSettingsJSONRequestBody{
|
||||
DisabledManagementGroups: []string{"test"},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestDNS_Integration(t *testing.T) {
|
||||
nsGroupReq := api.NameserverGroupRequest{
|
||||
Description: "Test",
|
||||
Enabled: true,
|
||||
Domains: []string{},
|
||||
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
|
||||
Name: "test",
|
||||
Nameservers: []api.Nameserver{
|
||||
{
|
||||
Ip: "8.8.8.8",
|
||||
NsType: api.NameserverNsTypeUdp,
|
||||
Port: 53,
|
||||
},
|
||||
},
|
||||
Primary: true,
|
||||
SearchDomainsEnabled: false,
|
||||
}
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
// Create
|
||||
nsGroup, err := c.DNS.CreateNameserverGroup(context.Background(), nsGroupReq)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List
|
||||
nsGroups, err := c.DNS.ListNameserverGroups(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *nsGroup, nsGroups[0])
|
||||
|
||||
// Update
|
||||
nsGroupReq.Description = "TestUpdate"
|
||||
nsGroup, err = c.DNS.UpdateNameserverGroup(context.Background(), nsGroup.Id, nsGroupReq)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "TestUpdate", nsGroup.Description)
|
||||
|
||||
// Delete
|
||||
err = c.DNS.DeleteNameserverGroup(context.Background(), nsGroup.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
// List again to ensure deletion
|
||||
nsGroups, err = c.DNS.ListNameserverGroups(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nsGroups, 0)
|
||||
})
|
||||
}
|
||||
26
shared/management/client/rest/events.go
Normal file
26
shared/management/client/rest/events.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// EventsAPI APIs for Events, do not use directly
|
||||
type EventsAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all events
|
||||
// See more: https://docs.netbird.io/api/resources/events#list-all-events
|
||||
func (a *EventsAPI) List(ctx context.Context) ([]api.Event, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/events", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Event](resp)
|
||||
return ret, err
|
||||
}
|
||||
70
shared/management/client/rest/events_test.go
Normal file
70
shared/management/client/rest/events_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testEvent = api.Event{
|
||||
Activity: "AccountCreate",
|
||||
ActivityCode: api.EventActivityCodeAccountCreate,
|
||||
}
|
||||
)
|
||||
|
||||
func TestEvents_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Event{testEvent})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Events.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testEvent, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvents_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/events", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Events.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEvents_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
// Do something that would trigger any event
|
||||
_, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
|
||||
Ephemeral: ptr(true),
|
||||
Name: "TestSetupKey",
|
||||
Type: "reusable",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
events, err := c.Events.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, events)
|
||||
})
|
||||
}
|
||||
40
shared/management/client/rest/geo.go
Normal file
40
shared/management/client/rest/geo.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// GeoLocationAPI APIs for Geo-Location, do not use directly
|
||||
type GeoLocationAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// ListCountries list all country codes
|
||||
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-country-codes
|
||||
func (a *GeoLocationAPI) ListCountries(ctx context.Context) ([]api.Country, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Country](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// ListCountryCities Get a list of all English city names for a given country code
|
||||
// See more: https://docs.netbird.io/api/resources/geo-locations#list-all-city-names-by-country
|
||||
func (a *GeoLocationAPI) ListCountryCities(ctx context.Context, countryCode string) ([]api.City, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/locations/countries/"+countryCode+"/cities", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.City](resp)
|
||||
return ret, err
|
||||
}
|
||||
101
shared/management/client/rest/geo_test.go
Normal file
101
shared/management/client/rest/geo_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testCountry = api.Country{
|
||||
CountryCode: "DE",
|
||||
CountryName: "Germany",
|
||||
}
|
||||
|
||||
testCity = api.City{
|
||||
CityName: "Berlin",
|
||||
GeonameId: 2950158,
|
||||
}
|
||||
)
|
||||
|
||||
func TestGeo_ListCountries_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Country{testCountry})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.GeoLocation.ListCountries(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testCountry, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeo_ListCountries_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/locations/countries", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.GeoLocation.ListCountries(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeo_ListCountryCities_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.City{testCity})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.GeoLocation.ListCountryCities(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testCity, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeo_ListCountryCities_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/locations/countries/Test/cities", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.GeoLocation.ListCountryCities(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeo_Integration(t *testing.T) {
|
||||
// Blackbox is initialized with empty GeoLocations
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
countries, err := c.GeoLocation.ListCountries(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, countries)
|
||||
|
||||
cities, err := c.GeoLocation.ListCountryCities(context.Background(), "DE")
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, cities)
|
||||
})
|
||||
}
|
||||
92
shared/management/client/rest/groups.go
Normal file
92
shared/management/client/rest/groups.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// GroupsAPI APIs for Groups, do not use directly
|
||||
type GroupsAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all groups
|
||||
// See more: https://docs.netbird.io/api/resources/groups#list-all-groups
|
||||
func (a *GroupsAPI) List(ctx context.Context) ([]api.Group, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Group](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get group info
|
||||
// See more: https://docs.netbird.io/api/resources/groups#retrieve-a-group
|
||||
func (a *GroupsAPI) Get(ctx context.Context, groupID string) (*api.Group, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/groups/"+groupID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Group](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create create new group
|
||||
// See more: https://docs.netbird.io/api/resources/groups#create-a-group
|
||||
func (a *GroupsAPI) Create(ctx context.Context, request api.PostApiGroupsJSONRequestBody) (*api.Group, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/groups", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Group](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update group info
|
||||
// See more: https://docs.netbird.io/api/resources/groups#update-a-group
|
||||
func (a *GroupsAPI) Update(ctx context.Context, groupID string, request api.PutApiGroupsGroupIdJSONRequestBody) (*api.Group, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/groups/"+groupID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Group](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete group
|
||||
// See more: https://docs.netbird.io/api/resources/groups#delete-a-group
|
||||
func (a *GroupsAPI) Delete(ctx context.Context, groupID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/groups/"+groupID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
215
shared/management/client/rest/groups_test.go
Normal file
215
shared/management/client/rest/groups_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testGroup = api.Group{
|
||||
Id: "Test",
|
||||
Name: "wow",
|
||||
PeersCount: 0,
|
||||
}
|
||||
)
|
||||
|
||||
func TestGroups_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Group{testGroup})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Groups.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testGroup, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Groups.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testGroup)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Groups.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testGroup, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Groups.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiGroupsJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testGroup)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Groups.Create(context.Background(), api.PostApiGroupsJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testGroup, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Groups.Create(context.Background(), api.PostApiGroupsJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiGroupsGroupIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testGroup)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Groups.Update(context.Background(), "Test", api.PutApiGroupsGroupIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testGroup, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Groups.Update(context.Background(), "Test", api.PutApiGroupsGroupIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Groups.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/groups/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Groups.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestGroups_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
groups, err := c.Groups.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, groups, 1)
|
||||
|
||||
group, err := c.Groups.Create(context.Background(), api.GroupRequest{
|
||||
Name: "Test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test", group.Name)
|
||||
assert.NotEmpty(t, group.Id)
|
||||
|
||||
group, err = c.Groups.Update(context.Background(), group.Id, api.GroupRequest{
|
||||
Name: "Testnt",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Testnt", group.Name)
|
||||
|
||||
err = c.Groups.Delete(context.Background(), group.Id)
|
||||
require.NoError(t, err)
|
||||
|
||||
groups, err = c.Groups.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, groups, 1)
|
||||
})
|
||||
}
|
||||
48
shared/management/client/rest/impersonation.go
Normal file
48
shared/management/client/rest/impersonation.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Impersonate returns a Client impersonated for a specific account
|
||||
func (c *Client) Impersonate(account string) *Client {
|
||||
client := NewWithOptions(
|
||||
WithManagementURL(c.managementURL),
|
||||
WithAuthHeader(c.authHeader),
|
||||
WithHttpClient(newImpersonatedHttpClient(c, account)),
|
||||
)
|
||||
return client
|
||||
}
|
||||
|
||||
type impersonatedHttpClient struct {
|
||||
baseClient HttpClient
|
||||
account string
|
||||
}
|
||||
|
||||
func newImpersonatedHttpClient(c *Client, account string) *impersonatedHttpClient {
|
||||
if hc, ok := c.httpClient.(*impersonatedHttpClient); ok {
|
||||
hc.account = account
|
||||
return hc
|
||||
}
|
||||
|
||||
return &impersonatedHttpClient{
|
||||
baseClient: c.httpClient,
|
||||
account: account,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *impersonatedHttpClient) Do(req *http.Request) (*http.Response, error) {
|
||||
parsedURL, err := url.Parse(req.URL.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
query.Set("account", c.account)
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
|
||||
req.URL = parsedURL
|
||||
|
||||
return c.baseClient.Do(req)
|
||||
}
|
||||
77
shared/management/client/rest/impersonation_test.go
Normal file
77
shared/management/client/rest/impersonation_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
var (
|
||||
testImpersonatedAccount = api.Account{
|
||||
Id: "ImpersonatedTest",
|
||||
Settings: api.AccountSettings{
|
||||
Extra: &api.AccountExtraSettings{
|
||||
PeerApprovalEnabled: false,
|
||||
},
|
||||
GroupsPropagationEnabled: ptr(true),
|
||||
JwtGroupsEnabled: ptr(false),
|
||||
PeerInactivityExpiration: 7,
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
PeerLoginExpiration: 24,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
RegularUsersViewBlocked: false,
|
||||
RoutingPeerDnsResolutionEnabled: ptr(false),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func TestImpersonation_Peers_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
impersonatedClient := c.Impersonate(testImpersonatedAccount.Id)
|
||||
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, r.URL.Query().Get("account"), testImpersonatedAccount.Id)
|
||||
retBytes, _ := json.Marshal([]api.Peer{testPeer})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := impersonatedClient.Peers.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testPeer, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestImpersonation_Change_Account(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
impersonatedClient := c.Impersonate(testImpersonatedAccount.Id)
|
||||
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, r.URL.Query().Get("account"), testImpersonatedAccount.Id)
|
||||
retBytes, _ := json.Marshal([]api.Peer{testPeer})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
_, err := impersonatedClient.Peers.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
impersonatedClient = impersonatedClient.Impersonate("another-test-account")
|
||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, r.URL.Query().Get("account"), "another-test-account")
|
||||
retBytes, _ := json.Marshal(testPeer)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
_, err = impersonatedClient.Peers.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
276
shared/management/client/rest/networks.go
Normal file
276
shared/management/client/rest/networks.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// NetworksAPI APIs for Networks, do not use directly
|
||||
type NetworksAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all networks
|
||||
// See more: https://docs.netbird.io/api/resources/networks#list-all-networks
|
||||
func (a *NetworksAPI) List(ctx context.Context) ([]api.Network, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Network](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get network info
|
||||
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network
|
||||
func (a *NetworksAPI) Get(ctx context.Context, networkID string) (*api.Network, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+networkID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Network](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create create new network
|
||||
// See more: https://docs.netbird.io/api/resources/networks#create-a-network
|
||||
func (a *NetworksAPI) Create(ctx context.Context, request api.PostApiNetworksJSONRequestBody) (*api.Network, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Network](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update network
|
||||
// See more: https://docs.netbird.io/api/resources/networks#update-a-network
|
||||
func (a *NetworksAPI) Update(ctx context.Context, networkID string, request api.PutApiNetworksNetworkIdJSONRequestBody) (*api.Network, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+networkID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Network](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete network
|
||||
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network
|
||||
func (a *NetworksAPI) Delete(ctx context.Context, networkID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+networkID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NetworkResourcesAPI APIs for Network Resources, do not use directly
|
||||
type NetworkResourcesAPI struct {
|
||||
c *Client
|
||||
networkID string
|
||||
}
|
||||
|
||||
// Resources APIs for network resources
|
||||
func (a *NetworksAPI) Resources(networkID string) *NetworkResourcesAPI {
|
||||
return &NetworkResourcesAPI{
|
||||
c: a.c,
|
||||
networkID: networkID,
|
||||
}
|
||||
}
|
||||
|
||||
// List list all resources in networks
|
||||
// See more: https://docs.netbird.io/api/resources/networks#list-all-network-resources
|
||||
func (a *NetworkResourcesAPI) List(ctx context.Context) ([]api.NetworkResource, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.NetworkResource](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get network resource info
|
||||
// See more: https://docs.netbird.io/api/resources/networks#retrieve-a-network-resource
|
||||
func (a *NetworkResourcesAPI) Get(ctx context.Context, networkResourceID string) (*api.NetworkResource, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkResource](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create create new network resource
|
||||
// See more: https://docs.netbird.io/api/resources/networks#create-a-network-resource
|
||||
func (a *NetworkResourcesAPI) Create(ctx context.Context, request api.PostApiNetworksNetworkIdResourcesJSONRequestBody) (*api.NetworkResource, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/resources", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkResource](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update network resource
|
||||
// See more: https://docs.netbird.io/api/resources/networks#update-a-network-resource
|
||||
func (a *NetworkResourcesAPI) Update(ctx context.Context, networkResourceID string, request api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody) (*api.NetworkResource, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkResource](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete network resource
|
||||
// See more: https://docs.netbird.io/api/resources/networks#delete-a-network-resource
|
||||
func (a *NetworkResourcesAPI) Delete(ctx context.Context, networkResourceID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/resources/"+networkResourceID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// NetworkRoutersAPI APIs for Network Routers, do not use directly
|
||||
type NetworkRoutersAPI struct {
|
||||
c *Client
|
||||
networkID string
|
||||
}
|
||||
|
||||
// Routers APIs for network routers
|
||||
func (a *NetworksAPI) Routers(networkID string) *NetworkRoutersAPI {
|
||||
return &NetworkRoutersAPI{
|
||||
c: a.c,
|
||||
networkID: networkID,
|
||||
}
|
||||
}
|
||||
|
||||
// List list all routers in networks
|
||||
// See more: https://docs.netbird.io/api/routers/networks#list-all-network-routers
|
||||
func (a *NetworkRoutersAPI) List(ctx context.Context) ([]api.NetworkRouter, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.NetworkRouter](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get network router info
|
||||
// See more: https://docs.netbird.io/api/routers/networks#retrieve-a-network-router
|
||||
func (a *NetworkRoutersAPI) Get(ctx context.Context, networkRouterID string) (*api.NetworkRouter, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkRouter](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create create new network router
|
||||
// See more: https://docs.netbird.io/api/routers/networks#create-a-network-router
|
||||
func (a *NetworkRoutersAPI) Create(ctx context.Context, request api.PostApiNetworksNetworkIdRoutersJSONRequestBody) (*api.NetworkRouter, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/networks/"+a.networkID+"/routers", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkRouter](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update network router
|
||||
// See more: https://docs.netbird.io/api/routers/networks#update-a-network-router
|
||||
func (a *NetworkRoutersAPI) Update(ctx context.Context, networkRouterID string, request api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody) (*api.NetworkRouter, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.NetworkRouter](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete network router
|
||||
// See more: https://docs.netbird.io/api/routers/networks#delete-a-network-router
|
||||
func (a *NetworkRoutersAPI) Delete(ctx context.Context, networkRouterID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/networks/"+a.networkID+"/routers/"+networkRouterID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
594
shared/management/client/rest/networks_test.go
Normal file
594
shared/management/client/rest/networks_test.go
Normal file
@@ -0,0 +1,594 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testNetwork = api.Network{
|
||||
Id: "Test",
|
||||
Name: "wow",
|
||||
}
|
||||
|
||||
testNetworkResource = api.NetworkResource{
|
||||
Description: ptr("meaw"),
|
||||
Id: "awa",
|
||||
}
|
||||
|
||||
testNetworkRouter = api.NetworkRouter{
|
||||
Id: "ouch",
|
||||
}
|
||||
)
|
||||
|
||||
func TestNetworks_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Network{testNetwork})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testNetwork, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testNetwork)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetwork, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiNetworksJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testNetwork)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Create(context.Background(), api.PostApiNetworksJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetwork, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Create(context.Background(), api.PostApiNetworksJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiNetworksNetworkIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testNetwork)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Update(context.Background(), "Test", api.PutApiNetworksNetworkIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetwork, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Update(context.Background(), "Test", api.PutApiNetworksNetworkIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Networks.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Networks.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworks_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
network, err := c.Networks.Create(context.Background(), api.NetworkRequest{
|
||||
Description: ptr("TestNetwork"),
|
||||
Name: "Test",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Test", network.Name)
|
||||
|
||||
networks, err := c.Networks.List(context.Background())
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, networks)
|
||||
|
||||
network, err = c.Networks.Update(context.Background(), "TestID", api.NetworkRequest{
|
||||
Description: ptr("TestNetwork?"),
|
||||
Name: "Test",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "TestNetwork?", *network.Description)
|
||||
|
||||
err = c.Networks.Delete(context.Background(), "TestID")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.NetworkResource{testNetworkResource})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Resources("Meow").List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testNetworkResource, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Resources("Meow").List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testNetworkResource)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Resources("Meow").Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetworkResource, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Resources("Meow").Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiNetworksNetworkIdResourcesJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testNetworkResource)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Resources("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdResourcesJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetworkResource, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Resources("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdResourcesJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testNetworkResource)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Resources("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetworkResource, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Resources("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Networks.Resources("Meow").Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/resources/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Networks.Resources("Meow").Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkResources_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
_, err := c.Networks.Resources("TestNetwork").Create(context.Background(), api.NetworkResourceRequest{
|
||||
Address: "test.com",
|
||||
Description: ptr("Description"),
|
||||
Enabled: false,
|
||||
Groups: []string{"test"},
|
||||
Name: "test",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = c.Networks.Resources("TestNetwork").List(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = c.Networks.Resources("TestNetwork").Get(context.Background(), "TestResource")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = c.Networks.Resources("TestNetwork").Update(context.Background(), "TestResource", api.NetworkResourceRequest{
|
||||
Address: "testnt.com",
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = c.Networks.Resources("TestNetwork").Delete(context.Background(), "TestResource")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.NetworkRouter{testNetworkRouter})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Routers("Meow").List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testNetworkRouter, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Routers("Meow").List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testNetworkRouter)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Routers("Meow").Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetworkRouter, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Routers("Meow").Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiNetworksNetworkIdRoutersJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test", *req.Peer)
|
||||
retBytes, _ := json.Marshal(testNetworkRouter)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Routers("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdRoutersJSONRequestBody{
|
||||
Peer: ptr("test"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetworkRouter, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Routers("Meow").Create(context.Background(), api.PostApiNetworksNetworkIdRoutersJSONRequestBody{
|
||||
Peer: ptr("test"),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "test", *req.Peer)
|
||||
retBytes, _ := json.Marshal(testNetworkRouter)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Routers("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody{
|
||||
Peer: ptr("test"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testNetworkRouter, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Networks.Routers("Meow").Update(context.Background(), "Test", api.PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody{
|
||||
Peer: ptr("test"),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Networks.Routers("Meow").Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/networks/Meow/routers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Networks.Routers("Meow").Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestNetworkRouters_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
_, err := c.Networks.Routers("TestNetwork").Create(context.Background(), api.NetworkRouterRequest{
|
||||
Enabled: false,
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
PeerGroups: ptr([]string{"test"}),
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = c.Networks.Routers("TestNetwork").List(context.Background())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = c.Networks.Routers("TestNetwork").Get(context.Background(), "TestRouter")
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = c.Networks.Routers("TestNetwork").Update(context.Background(), "TestRouter", api.NetworkRouterRequest{
|
||||
Enabled: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = c.Networks.Routers("TestNetwork").Delete(context.Background(), "TestRouter")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
44
shared/management/client/rest/options.go
Normal file
44
shared/management/client/rest/options.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package rest
|
||||
|
||||
import "net/http"
|
||||
|
||||
// option modifier for creation of Client
|
||||
type option func(*Client)
|
||||
|
||||
// HTTPClient interface for HTTP client
|
||||
type HttpClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// WithHTTPClient overrides HTTPClient used
|
||||
func WithHttpClient(client HttpClient) option {
|
||||
return func(c *Client) {
|
||||
c.httpClient = client
|
||||
}
|
||||
}
|
||||
|
||||
// WithBearerToken uses provided bearer token acquired from SSO for authentication
|
||||
func WithBearerToken(token string) option {
|
||||
return WithAuthHeader("Bearer " + token)
|
||||
}
|
||||
|
||||
// WithPAT uses provided Personal Access Token
|
||||
// (created from NetBird Management Dashboard) for authentication
|
||||
func WithPAT(token string) option {
|
||||
return WithAuthHeader("Token " + token)
|
||||
}
|
||||
|
||||
// WithManagementURL overrides target NetBird Management server
|
||||
func WithManagementURL(url string) option {
|
||||
return func(c *Client) {
|
||||
c.managementURL = url
|
||||
}
|
||||
}
|
||||
|
||||
// WithAuthHeader overrides auth header completely, this should generally not be used
|
||||
// and WithBearerToken or WithPAT should be used instead
|
||||
func WithAuthHeader(value string) option {
|
||||
return func(c *Client) {
|
||||
c.authHeader = value
|
||||
}
|
||||
}
|
||||
108
shared/management/client/rest/peers.go
Normal file
108
shared/management/client/rest/peers.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// PeersAPI APIs for peers, do not use directly
|
||||
type PeersAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// PeersListOption options for Peers List API
|
||||
type PeersListOption func() (string, string)
|
||||
|
||||
func PeerNameFilter(name string) PeersListOption {
|
||||
return func() (string, string) {
|
||||
return "name", name
|
||||
}
|
||||
}
|
||||
|
||||
func PeerIPFilter(ip string) PeersListOption {
|
||||
return func() (string, string) {
|
||||
return "ip", ip
|
||||
}
|
||||
}
|
||||
|
||||
// List list all peers
|
||||
// See more: https://docs.netbird.io/api/resources/peers#list-all-peers
|
||||
func (a *PeersAPI) List(ctx context.Context, opts ...PeersListOption) ([]api.Peer, error) {
|
||||
query := make(map[string]string)
|
||||
for _, o := range opts {
|
||||
k, v := o()
|
||||
query[k] = v
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers", nil, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Peer](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get retrieve a peer
|
||||
// See more: https://docs.netbird.io/api/resources/peers#retrieve-a-peer
|
||||
func (a *PeersAPI) Get(ctx context.Context, peerID string) (*api.Peer, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Peer](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update information for a peer
|
||||
// See more: https://docs.netbird.io/api/resources/peers#update-a-peer
|
||||
func (a *PeersAPI) Update(ctx context.Context, peerID string, request api.PutApiPeersPeerIdJSONRequestBody) (*api.Peer, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/peers/"+peerID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Peer](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete a peer
|
||||
// See more: https://docs.netbird.io/api/resources/peers#delete-a-peer
|
||||
func (a *PeersAPI) Delete(ctx context.Context, peerID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/peers/"+peerID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListAccessiblePeers list all peers that the specified peer can connect to within the network
|
||||
// See more: https://docs.netbird.io/api/resources/peers#list-accessible-peers
|
||||
func (a *PeersAPI) ListAccessiblePeers(ctx context.Context, peerID string) ([]api.Peer, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/peers/"+peerID+"/accessible-peers", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Peer](resp)
|
||||
return ret, err
|
||||
}
|
||||
212
shared/management/client/rest/peers_test.go
Normal file
212
shared/management/client/rest/peers_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testPeer = api.Peer{
|
||||
ApprovalRequired: false,
|
||||
Connected: false,
|
||||
ConnectionIp: "127.0.0.1",
|
||||
DnsLabel: "test",
|
||||
Id: "Test",
|
||||
}
|
||||
)
|
||||
|
||||
func TestPeers_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Peer{testPeer})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Peers.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testPeer, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Peers.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testPeer)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Peers.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testPeer, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Peers.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiPeersPeerIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, req.InactivityExpirationEnabled)
|
||||
retBytes, _ := json.Marshal(testPeer)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Peers.Update(context.Background(), "Test", api.PutApiPeersPeerIdJSONRequestBody{
|
||||
InactivityExpirationEnabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testPeer, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Peers.Update(context.Background(), "Test", api.PutApiPeersPeerIdJSONRequestBody{
|
||||
InactivityExpirationEnabled: false,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Peers.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Peers.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_ListAccessiblePeers_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Peer{testPeer})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Peers.ListAccessiblePeers(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testPeer, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_ListAccessiblePeers_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/peers/Test/accessible-peers", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Peers.ListAccessiblePeers(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeers_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
peers, err := c.Peers.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, peers)
|
||||
|
||||
filteredPeers, err := c.Peers.List(context.Background(), rest.PeerIPFilter("192.168.10.0"))
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, filteredPeers)
|
||||
|
||||
peer, err := c.Peers.Get(context.Background(), peers[0].Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peers[0].Id, peer.Id)
|
||||
|
||||
peer, err = c.Peers.Update(context.Background(), peer.Id, api.PeerRequest{
|
||||
LoginExpirationEnabled: true,
|
||||
Name: "Test",
|
||||
SshEnabled: false,
|
||||
ApprovalRequired: ptr(false),
|
||||
InactivityExpirationEnabled: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, peer.LoginExpirationEnabled)
|
||||
|
||||
accessiblePeers, err := c.Peers.ListAccessiblePeers(context.Background(), peer.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, accessiblePeers)
|
||||
|
||||
err = c.Peers.Delete(context.Background(), peer.Id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
96
shared/management/client/rest/policies.go
Normal file
96
shared/management/client/rest/policies.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// PoliciesAPI APIs for Policies, do not use directly
|
||||
type PoliciesAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all policies
|
||||
// See more: https://docs.netbird.io/api/resources/policies#list-all-policies
|
||||
func (a *PoliciesAPI) List(ctx context.Context) ([]api.Policy, error) {
|
||||
path := "/api/policies"
|
||||
|
||||
resp, err := a.c.NewRequest(ctx, "GET", path, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Policy](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get policy info
|
||||
// See more: https://docs.netbird.io/api/resources/policies#retrieve-a-policy
|
||||
func (a *PoliciesAPI) Get(ctx context.Context, policyID string) (*api.Policy, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/policies/"+policyID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Policy](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create create new policy
|
||||
// See more: https://docs.netbird.io/api/resources/policies#create-a-policy
|
||||
func (a *PoliciesAPI) Create(ctx context.Context, request api.PostApiPoliciesJSONRequestBody) (*api.Policy, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/policies", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Policy](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update policy info
|
||||
// See more: https://docs.netbird.io/api/resources/policies#update-a-policy
|
||||
func (a *PoliciesAPI) Update(ctx context.Context, policyID string, request api.PutApiPoliciesPolicyIdJSONRequestBody) (*api.Policy, error) {
|
||||
path := "/api/policies/" + policyID
|
||||
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", path, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Policy](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete policy
|
||||
// See more: https://docs.netbird.io/api/resources/policies#delete-a-policy
|
||||
func (a *PoliciesAPI) Delete(ctx context.Context, policyID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/policies/"+policyID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
241
shared/management/client/rest/policies_test.go
Normal file
241
shared/management/client/rest/policies_test.go
Normal file
@@ -0,0 +1,241 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testPolicy = api.Policy{
|
||||
Name: "wow",
|
||||
Id: ptr("Test"),
|
||||
Enabled: false,
|
||||
}
|
||||
)
|
||||
|
||||
func TestPolicies_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Policy{testPolicy})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Policies.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testPolicy, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Policies.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testPolicy)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Policies.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testPolicy, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Policies.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testPolicy)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Policies.Create(context.Background(), api.PostApiPoliciesJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testPolicy, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Policies.Create(context.Background(), api.PostApiPoliciesJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiPoliciesPolicyIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testPolicy)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Policies.Update(context.Background(), "Test", api.PutApiPoliciesPolicyIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testPolicy, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Policies.Update(context.Background(), "Test", api.PutApiPoliciesPolicyIdJSONRequestBody{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Policies.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/policies/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Policies.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPolicies_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
policies, err := c.Policies.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, policies)
|
||||
|
||||
policy, err := c.Policies.Get(context.Background(), *policies[0].Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *policies[0].Id, *policy.Id)
|
||||
|
||||
policy, err = c.Policies.Update(context.Background(), *policy.Id, api.PolicyCreate{
|
||||
Description: ptr("Test Policy"),
|
||||
Enabled: false,
|
||||
Name: "Test",
|
||||
Rules: []api.PolicyRuleUpdate{
|
||||
{
|
||||
Action: api.PolicyRuleUpdateAction(policy.Rules[0].Action),
|
||||
Bidirectional: true,
|
||||
Description: ptr("Test Policy"),
|
||||
Sources: ptr([]string{(*policy.Rules[0].Sources)[0].Id}),
|
||||
Destinations: ptr([]string{(*policy.Rules[0].Destinations)[0].Id}),
|
||||
Enabled: false,
|
||||
Protocol: api.PolicyRuleUpdateProtocolAll,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test Policy", *policy.Rules[0].Description)
|
||||
|
||||
policy, err = c.Policies.Create(context.Background(), api.PolicyUpdate{
|
||||
Description: ptr("Test Policy 2"),
|
||||
Enabled: false,
|
||||
Name: "Test",
|
||||
Rules: []api.PolicyRuleUpdate{
|
||||
{
|
||||
Action: api.PolicyRuleUpdateAction(policy.Rules[0].Action),
|
||||
Bidirectional: true,
|
||||
Description: ptr("Test Policy 2"),
|
||||
Sources: ptr([]string{(*policy.Rules[0].Sources)[0].Id}),
|
||||
Destinations: ptr([]string{(*policy.Rules[0].Destinations)[0].Id}),
|
||||
Enabled: false,
|
||||
Protocol: api.PolicyRuleUpdateProtocolAll,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: nil,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.Policies.Delete(context.Background(), *policy.Id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
92
shared/management/client/rest/posturechecks.go
Normal file
92
shared/management/client/rest/posturechecks.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// PostureChecksAPI APIs for PostureChecks, do not use directly
|
||||
type PostureChecksAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all posture checks
|
||||
// See more: https://docs.netbird.io/api/resources/posture-checks#list-all-posture-checks
|
||||
func (a *PostureChecksAPI) List(ctx context.Context) ([]api.PostureCheck, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.PostureCheck](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get posture check info
|
||||
// See more: https://docs.netbird.io/api/resources/posture-checks#retrieve-a-posture-check
|
||||
func (a *PostureChecksAPI) Get(ctx context.Context, postureCheckID string) (*api.PostureCheck, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/posture-checks/"+postureCheckID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PostureCheck](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create create new posture check
|
||||
// See more: https://docs.netbird.io/api/resources/posture-checks#create-a-posture-check
|
||||
func (a *PostureChecksAPI) Create(ctx context.Context, request api.PostApiPostureChecksJSONRequestBody) (*api.PostureCheck, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/posture-checks", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PostureCheck](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update posture check info
|
||||
// See more: https://docs.netbird.io/api/resources/posture-checks#update-a-posture-check
|
||||
func (a *PostureChecksAPI) Update(ctx context.Context, postureCheckID string, request api.PutApiPostureChecksPostureCheckIdJSONRequestBody) (*api.PostureCheck, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/posture-checks/"+postureCheckID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PostureCheck](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete posture check
|
||||
// See more: https://docs.netbird.io/api/resources/posture-checks#delete-a-posture-check
|
||||
func (a *PostureChecksAPI) Delete(ctx context.Context, postureCheckID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/posture-checks/"+postureCheckID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
233
shared/management/client/rest/posturechecks_test.go
Normal file
233
shared/management/client/rest/posturechecks_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testPostureCheck = api.PostureCheck{
|
||||
Id: "Test",
|
||||
Name: "wow",
|
||||
}
|
||||
)
|
||||
|
||||
func TestPostureChecks_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.PostureCheck{testPostureCheck})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.PostureChecks.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testPostureCheck, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.PostureChecks.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testPostureCheck)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.PostureChecks.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testPostureCheck, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.PostureChecks.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostureCheckUpdate
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testPostureCheck)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testPostureCheck, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostureCheckUpdate
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "weaw", req.Name)
|
||||
retBytes, _ := json.Marshal(testPostureCheck)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.PostureChecks.Update(context.Background(), "Test", api.PostureCheckUpdate{
|
||||
Name: "weaw",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testPostureCheck, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.PostureChecks.Update(context.Background(), "Test", api.PostureCheckUpdate{
|
||||
Name: "weaw",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.PostureChecks.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/posture-checks/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.PostureChecks.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostureChecks_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
check, err := c.PostureChecks.Create(context.Background(), api.PostureCheckUpdate{
|
||||
Name: "Test",
|
||||
Description: "Testing",
|
||||
Checks: &api.Checks{
|
||||
OsVersionCheck: &api.OSVersionCheck{
|
||||
Windows: &api.MinKernelVersionCheck{
|
||||
MinKernelVersion: "0.0.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test", check.Name)
|
||||
|
||||
checks, err := c.PostureChecks.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, checks, 1)
|
||||
|
||||
check, err = c.PostureChecks.Update(context.Background(), check.Id, api.PostureCheckUpdate{
|
||||
Name: "Tests",
|
||||
Description: "Testings",
|
||||
Checks: &api.Checks{
|
||||
GeoLocationCheck: &api.GeoLocationCheck{
|
||||
Action: api.GeoLocationCheckActionAllow, Locations: []api.Location{
|
||||
{
|
||||
CityName: ptr("Cairo"),
|
||||
CountryCode: "EG",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Testings", *check.Description)
|
||||
|
||||
check, err = c.PostureChecks.Get(context.Background(), check.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Tests", check.Name)
|
||||
|
||||
err = c.PostureChecks.Delete(context.Background(), check.Id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
92
shared/management/client/rest/routes.go
Normal file
92
shared/management/client/rest/routes.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// RoutesAPI APIs for Routes, do not use directly
|
||||
type RoutesAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all routes
|
||||
// See more: https://docs.netbird.io/api/resources/routes#list-all-routes
|
||||
func (a *RoutesAPI) List(ctx context.Context) ([]api.Route, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.Route](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get route info
|
||||
// See more: https://docs.netbird.io/api/resources/routes#retrieve-a-route
|
||||
func (a *RoutesAPI) Get(ctx context.Context, routeID string) (*api.Route, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/routes/"+routeID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Route](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create create new route
|
||||
// See more: https://docs.netbird.io/api/resources/routes#create-a-route
|
||||
func (a *RoutesAPI) Create(ctx context.Context, request api.PostApiRoutesJSONRequestBody) (*api.Route, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/routes", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Route](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update route info
|
||||
// See more: https://docs.netbird.io/api/resources/routes#update-a-route
|
||||
func (a *RoutesAPI) Update(ctx context.Context, routeID string, request api.PutApiRoutesRouteIdJSONRequestBody) (*api.Route, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/routes/"+routeID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.Route](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete route
|
||||
// See more: https://docs.netbird.io/api/resources/routes#delete-a-route
|
||||
func (a *RoutesAPI) Delete(ctx context.Context, routeID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/routes/"+routeID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
231
shared/management/client/rest/routes_test.go
Normal file
231
shared/management/client/rest/routes_test.go
Normal file
@@ -0,0 +1,231 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testRoute = api.Route{
|
||||
Id: "Test",
|
||||
Domains: ptr([]string{"google.com"}),
|
||||
}
|
||||
)
|
||||
|
||||
func TestRoutes_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.Route{testRoute})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Routes.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testRoute, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Routes.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testRoute)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Routes.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testRoute, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Routes.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiRoutesJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "meow", req.Description)
|
||||
retBytes, _ := json.Marshal(testRoute)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Routes.Create(context.Background(), api.PostApiRoutesJSONRequestBody{
|
||||
Description: "meow",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testRoute, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Routes.Create(context.Background(), api.PostApiRoutesJSONRequestBody{
|
||||
Description: "meow",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiRoutesRouteIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "meow", req.Description)
|
||||
retBytes, _ := json.Marshal(testRoute)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Routes.Update(context.Background(), "Test", api.PutApiRoutesRouteIdJSONRequestBody{
|
||||
Description: "meow",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testRoute, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Routes.Update(context.Background(), "Test", api.PutApiRoutesRouteIdJSONRequestBody{
|
||||
Description: "meow",
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Routes.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/routes/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Routes.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestRoutes_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
route, err := c.Routes.Create(context.Background(), api.RouteRequest{
|
||||
Description: "Meow",
|
||||
Enabled: false,
|
||||
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
|
||||
PeerGroups: ptr([]string{"cs1tnh0hhcjnqoiuebeg"}),
|
||||
Domains: ptr([]string{"google.com"}),
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
KeepRoute: false,
|
||||
NetworkId: "Test",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test", route.NetworkId)
|
||||
|
||||
routes, err := c.Routes.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, routes, 1)
|
||||
|
||||
route, err = c.Routes.Update(context.Background(), route.Id, api.RouteRequest{
|
||||
Description: "Testings",
|
||||
Enabled: false,
|
||||
Groups: []string{"cs1tnh0hhcjnqoiuebeg"},
|
||||
PeerGroups: ptr([]string{"cs1tnh0hhcjnqoiuebeg"}),
|
||||
Domains: ptr([]string{"google.com"}),
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
KeepRoute: false,
|
||||
NetworkId: "Tests",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Testings", route.Description)
|
||||
|
||||
route, err = c.Routes.Get(context.Background(), route.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Tests", route.NetworkId)
|
||||
|
||||
err = c.Routes.Delete(context.Background(), route.Id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
94
shared/management/client/rest/setupkeys.go
Normal file
94
shared/management/client/rest/setupkeys.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// SetupKeysAPI APIs for Setup keys, do not use directly
|
||||
type SetupKeysAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all setup keys
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#list-all-setup-keys
|
||||
func (a *SetupKeysAPI) List(ctx context.Context) ([]api.SetupKey, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.SetupKey](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get setup key info
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#retrieve-a-setup-key
|
||||
func (a *SetupKeysAPI) Get(ctx context.Context, setupKeyID string) (*api.SetupKey, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/setup-keys/"+setupKeyID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.SetupKey](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create generate new Setup Key
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#create-a-setup-key
|
||||
func (a *SetupKeysAPI) Create(ctx context.Context, request api.PostApiSetupKeysJSONRequestBody) (*api.SetupKeyClear, error) {
|
||||
path := "/api/setup-keys"
|
||||
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", path, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.SetupKeyClear](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update generate new Setup Key
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#update-a-setup-key
|
||||
func (a *SetupKeysAPI) Update(ctx context.Context, setupKeyID string, request api.PutApiSetupKeysKeyIdJSONRequestBody) (*api.SetupKey, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/setup-keys/"+setupKeyID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.SetupKey](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete setup key
|
||||
// See more: https://docs.netbird.io/api/resources/setup-keys#delete-a-setup-key
|
||||
func (a *SetupKeysAPI) Delete(ctx context.Context, setupKeyID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/setup-keys/"+setupKeyID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
232
shared/management/client/rest/setupkeys_test.go
Normal file
232
shared/management/client/rest/setupkeys_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testSetupKey = api.SetupKey{
|
||||
Id: "Test",
|
||||
Name: "wow",
|
||||
AutoGroups: []string{"meow"},
|
||||
Ephemeral: true,
|
||||
}
|
||||
|
||||
testSteupKeyGenerated = api.SetupKeyClear{
|
||||
Id: "Test",
|
||||
Name: "wow",
|
||||
AutoGroups: []string{"meow"},
|
||||
Ephemeral: true,
|
||||
Key: "shhh",
|
||||
}
|
||||
)
|
||||
|
||||
func TestSetupKeys_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.SetupKey{testSetupKey})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.SetupKeys.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testSetupKey, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.SetupKeys.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testSetupKey)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.SetupKeys.Get(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSetupKey, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.SetupKeys.Get(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiSetupKeysJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, req.ExpiresIn)
|
||||
retBytes, _ := json.Marshal(testSteupKeyGenerated)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.SetupKeys.Create(context.Background(), api.PostApiSetupKeysJSONRequestBody{
|
||||
ExpiresIn: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSteupKeyGenerated, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.SetupKeys.Create(context.Background(), api.PostApiSetupKeysJSONRequestBody{
|
||||
ExpiresIn: 5,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiSetupKeysKeyIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, req.Revoked)
|
||||
retBytes, _ := json.Marshal(testSetupKey)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.SetupKeys.Update(context.Background(), "Test", api.PutApiSetupKeysKeyIdJSONRequestBody{
|
||||
Revoked: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testSetupKey, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.SetupKeys.Update(context.Background(), "Test", api.PutApiSetupKeysKeyIdJSONRequestBody{
|
||||
Revoked: true,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.SetupKeys.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/setup-keys/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.SetupKeys.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestSetupKeys_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
group, err := c.Groups.Create(context.Background(), api.GroupRequest{
|
||||
Name: "Test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
skClear, err := c.SetupKeys.Create(context.Background(), api.CreateSetupKeyRequest{
|
||||
AutoGroups: []string{group.Id},
|
||||
Ephemeral: ptr(false),
|
||||
Name: "test",
|
||||
Type: "reusable",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, skClear.Valid)
|
||||
|
||||
keys, err := c.SetupKeys.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, keys, 2)
|
||||
|
||||
sk, err := c.SetupKeys.Update(context.Background(), skClear.Id, api.SetupKeyRequest{
|
||||
Revoked: true,
|
||||
AutoGroups: []string{group.Id},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
sk, err = c.SetupKeys.Get(context.Background(), sk.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, false, sk.Valid)
|
||||
|
||||
err = c.SetupKeys.Delete(context.Background(), sk.Id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
74
shared/management/client/rest/tokens.go
Normal file
74
shared/management/client/rest/tokens.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// TokensAPI APIs for PATs, do not use directly
|
||||
type TokensAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list user tokens
|
||||
// See more: https://docs.netbird.io/api/resources/tokens#list-all-tokens
|
||||
func (a *TokensAPI) List(ctx context.Context, userID string) ([]api.PersonalAccessToken, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.PersonalAccessToken](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Get get user token info
|
||||
// See more: https://docs.netbird.io/api/resources/tokens#retrieve-a-token
|
||||
func (a *TokensAPI) Get(ctx context.Context, userID, tokenID string) (*api.PersonalAccessToken, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/"+userID+"/tokens/"+tokenID, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PersonalAccessToken](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Create generate new PAT for user
|
||||
// See more: https://docs.netbird.io/api/resources/tokens#create-a-token
|
||||
func (a *TokensAPI) Create(ctx context.Context, userID string, request api.PostApiUsersUserIdTokensJSONRequestBody) (*api.PersonalAccessTokenGenerated, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/tokens", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.PersonalAccessTokenGenerated](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete user token
|
||||
// See more: https://docs.netbird.io/api/resources/tokens#delete-a-token
|
||||
func (a *TokensAPI) Delete(ctx context.Context, userID, tokenID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID+"/tokens/"+tokenID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
180
shared/management/client/rest/tokens_test.go
Normal file
180
shared/management/client/rest/tokens_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testToken = api.PersonalAccessToken{
|
||||
Id: "Test",
|
||||
CreatedAt: time.Time{},
|
||||
CreatedBy: "meow",
|
||||
ExpirationDate: time.Time{},
|
||||
LastUsed: nil,
|
||||
Name: "wow",
|
||||
}
|
||||
|
||||
testTokenGenerated = api.PersonalAccessTokenGenerated{
|
||||
PersonalAccessToken: testToken,
|
||||
PlainToken: "shhh",
|
||||
}
|
||||
)
|
||||
|
||||
func TestTokens_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.PersonalAccessToken{testToken})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Tokens.List(context.Background(), "meow")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testToken, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokens_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Tokens.List(context.Background(), "meow")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokens_Get_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testToken)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Tokens.Get(context.Background(), "meow", "Test")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testToken, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokens_Get_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Tokens.Get(context.Background(), "meow", "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokens_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiUsersUserIdTokensJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, req.ExpiresIn)
|
||||
retBytes, _ := json.Marshal(testTokenGenerated)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Tokens.Create(context.Background(), "meow", api.PostApiUsersUserIdTokensJSONRequestBody{
|
||||
ExpiresIn: 5,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testTokenGenerated, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokens_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/meow/tokens", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Tokens.Create(context.Background(), "meow", api.PostApiUsersUserIdTokensJSONRequestBody{
|
||||
ExpiresIn: 5,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokens_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Tokens.Delete(context.Background(), "meow", "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokens_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/meow/tokens/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Tokens.Delete(context.Background(), "meow", "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokens_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
tokenClear, err := c.Tokens.Create(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", api.PersonalAccessTokenRequest{
|
||||
Name: "Test",
|
||||
ExpiresIn: 365,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test", tokenClear.PersonalAccessToken.Name)
|
||||
|
||||
tokens, err := c.Tokens.List(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 2)
|
||||
|
||||
token, err := c.Tokens.Get(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", tokenClear.PersonalAccessToken.Id)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Test", token.Name)
|
||||
|
||||
err = c.Tokens.Delete(context.Background(), "a23efe53-63fb-11ec-90d6-0242ac120003", token.Id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
107
shared/management/client/rest/users.go
Normal file
107
shared/management/client/rest/users.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package rest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
)
|
||||
|
||||
// UsersAPI APIs for users, do not use directly
|
||||
type UsersAPI struct {
|
||||
c *Client
|
||||
}
|
||||
|
||||
// List list all users, only returns one user always
|
||||
// See more: https://docs.netbird.io/api/resources/users#list-all-users
|
||||
func (a *UsersAPI) List(ctx context.Context) ([]api.User, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/users", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[[]api.User](resp)
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// Create create user
|
||||
// See more: https://docs.netbird.io/api/resources/users#create-a-user
|
||||
func (a *UsersAPI) Create(ctx context.Context, request api.PostApiUsersJSONRequestBody) (*api.User, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/users", bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.User](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Update update user settings
|
||||
// See more: https://docs.netbird.io/api/resources/users#update-a-user
|
||||
func (a *UsersAPI) Update(ctx context.Context, userID string, request api.PutApiUsersUserIdJSONRequestBody) (*api.User, error) {
|
||||
requestBytes, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := a.c.NewRequest(ctx, "PUT", "/api/users/"+userID, bytes.NewReader(requestBytes), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
ret, err := parseResponse[api.User](resp)
|
||||
return &ret, err
|
||||
}
|
||||
|
||||
// Delete delete user
|
||||
// See more: https://docs.netbird.io/api/resources/users#delete-a-user
|
||||
func (a *UsersAPI) Delete(ctx context.Context, userID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/users/"+userID, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResendInvitation resend user invitation
|
||||
// See more: https://docs.netbird.io/api/resources/users#resend-user-invitation
|
||||
func (a *UsersAPI) ResendInvitation(ctx context.Context, userID string) error {
|
||||
resp, err := a.c.NewRequest(ctx, "POST", "/api/users/"+userID+"/invite", nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Current gets the current user info
|
||||
// See more: https://docs.netbird.io/api/resources/users#retrieve-current-user
|
||||
func (a *UsersAPI) Current(ctx context.Context) (*api.User, error) {
|
||||
resp, err := a.c.NewRequest(ctx, "GET", "/api/users/current", nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
defer resp.Body.Close()
|
||||
}
|
||||
|
||||
ret, err := parseResponse[api.User](resp)
|
||||
return &ret, err
|
||||
}
|
||||
258
shared/management/client/rest/users_test.go
Normal file
258
shared/management/client/rest/users_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package rest_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/client/rest"
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/util"
|
||||
)
|
||||
|
||||
var (
|
||||
testUser = api.User{
|
||||
Id: "Test",
|
||||
AutoGroups: []string{"test-group"},
|
||||
Email: "test@test.com",
|
||||
IsBlocked: false,
|
||||
IsCurrent: ptr(false),
|
||||
IsServiceUser: ptr(false),
|
||||
Issued: ptr("api"),
|
||||
LastLogin: &time.Time{},
|
||||
Name: "M. Essam",
|
||||
Role: "user",
|
||||
Status: api.UserStatusActive,
|
||||
}
|
||||
)
|
||||
|
||||
func TestUsers_List_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal([]api.User{testUser})
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Users.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, ret, 1)
|
||||
assert.Equal(t, testUser, ret[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_List_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Users.List(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_Create_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PostApiUsersJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"meow"}, req.AutoGroups)
|
||||
retBytes, _ := json.Marshal(testUser)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Users.Create(context.Background(), api.PostApiUsersJSONRequestBody{
|
||||
AutoGroups: []string{"meow"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testUser, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_Create_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Users.Create(context.Background(), api.PostApiUsersJSONRequestBody{
|
||||
AutoGroups: []string{"meow"},
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_Update_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "PUT", r.Method)
|
||||
reqBytes, err := io.ReadAll(r.Body)
|
||||
require.NoError(t, err)
|
||||
var req api.PutApiUsersUserIdJSONRequestBody
|
||||
err = json.Unmarshal(reqBytes, &req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, req.IsBlocked)
|
||||
retBytes, _ := json.Marshal(testUser)
|
||||
_, err = w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Users.Update(context.Background(), "Test", api.PutApiUsersUserIdJSONRequestBody{
|
||||
IsBlocked: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testUser, *ret)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestUsers_Update_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Users.Update(context.Background(), "Test", api.PutApiUsersUserIdJSONRequestBody{
|
||||
IsBlocked: true,
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Nil(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_Delete_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "DELETE", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Users.Delete(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_Delete_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/Test", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Users.Delete(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_ResendInvitation_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
w.WriteHeader(200)
|
||||
})
|
||||
err := c.Users.ResendInvitation(context.Background(), "Test")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_ResendInvitation_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/Test/invite", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
|
||||
w.WriteHeader(404)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
err := c.Users.ResendInvitation(context.Background(), "Test")
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "Not found", err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_Current_200(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(testUser)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Users.Current(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, testUser, *ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_Current_Err(t *testing.T) {
|
||||
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/users/current", func(w http.ResponseWriter, r *http.Request) {
|
||||
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
|
||||
w.WriteHeader(400)
|
||||
_, err := w.Write(retBytes)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
ret, err := c.Users.Current(context.Background())
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, "No", err.Error())
|
||||
assert.Empty(t, ret)
|
||||
})
|
||||
}
|
||||
|
||||
func TestUsers_Integration(t *testing.T) {
|
||||
withBlackBoxServer(t, func(c *rest.Client) {
|
||||
// rest client PAT is owner's
|
||||
current, err := c.Users.Current(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "a23efe53-63fb-11ec-90d6-0242ac120003", current.Id)
|
||||
assert.Equal(t, "owner", current.Role)
|
||||
|
||||
user, err := c.Users.Create(context.Background(), api.UserCreateRequest{
|
||||
AutoGroups: []string{},
|
||||
Email: ptr("test@example.com"),
|
||||
IsServiceUser: true,
|
||||
Name: ptr("Nobody"),
|
||||
Role: "user",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Nobody", user.Name)
|
||||
|
||||
users, err := c.Users.List(context.Background())
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, users)
|
||||
|
||||
user, err = c.Users.Update(context.Background(), user.Id, api.UserRequest{
|
||||
AutoGroups: []string{},
|
||||
Role: "admin",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "admin", user.Role)
|
||||
|
||||
err = c.Users.Delete(context.Background(), user.Id)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
45
shared/management/domain/domain.go
Normal file
45
shared/management/domain/domain.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
// Domain represents a punycode-encoded domain string.
|
||||
// This should only be converted from a string when the string already is in punycode, otherwise use FromString.
|
||||
type Domain string
|
||||
|
||||
// String converts the Domain to a non-punycode string.
|
||||
// For an infallible conversion, use SafeString.
|
||||
func (d Domain) String() (string, error) {
|
||||
unicode, err := idna.ToUnicode(string(d))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return unicode, nil
|
||||
}
|
||||
|
||||
// SafeString converts the Domain to a non-punycode string, falling back to the punycode string if conversion fails.
|
||||
func (d Domain) SafeString() string {
|
||||
str, err := d.String()
|
||||
if err != nil {
|
||||
return string(d)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// PunycodeString returns the punycode representation of the Domain.
|
||||
// This should only be used if a punycode domain is expected but only a string is supported.
|
||||
func (d Domain) PunycodeString() string {
|
||||
return string(d)
|
||||
}
|
||||
|
||||
// FromString creates a Domain from a string, converting it to punycode.
|
||||
func FromString(s string) (Domain, error) {
|
||||
ascii, err := idna.ToASCII(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return Domain(strings.ToLower(ascii)), nil
|
||||
}
|
||||
1
shared/management/domain/go.sum
Normal file
1
shared/management/domain/go.sum
Normal file
@@ -0,0 +1 @@
|
||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
||||
108
shared/management/domain/list.go
Normal file
108
shared/management/domain/list.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// List is a slice of punycode-encoded domain strings.
|
||||
type List []Domain
|
||||
|
||||
// ToStringList converts a List to a slice of string.
|
||||
func (d List) ToStringList() ([]string, error) {
|
||||
var list []string
|
||||
for _, domain := range d {
|
||||
s, err := domain.String()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, s)
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
// ToPunycodeList converts the List to a slice of Punycode-encoded domain strings.
|
||||
func (d List) ToPunycodeList() []string {
|
||||
var list []string
|
||||
for _, domain := range d {
|
||||
list = append(list, string(domain))
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// ToSafeStringList converts the List to a slice of non-punycode strings.
|
||||
// If a domain cannot be converted, the original string is used.
|
||||
func (d List) ToSafeStringList() []string {
|
||||
var list []string
|
||||
for _, domain := range d {
|
||||
list = append(list, domain.SafeString())
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// String converts List to a comma-separated string.
|
||||
func (d List) String() (string, error) {
|
||||
list, err := d.ToStringList()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.Join(list, ", "), nil
|
||||
}
|
||||
|
||||
// SafeString converts List to a comma-separated non-punycode string.
|
||||
// If a domain cannot be converted, the original string is used.
|
||||
func (d List) SafeString() string {
|
||||
str, err := d.String()
|
||||
if err != nil {
|
||||
return d.PunycodeString()
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// PunycodeString converts the List to a comma-separated string of Punycode-encoded domains.
|
||||
func (d List) PunycodeString() string {
|
||||
return strings.Join(d.ToPunycodeList(), ", ")
|
||||
}
|
||||
|
||||
func (d List) Equal(domains List) bool {
|
||||
if len(d) != len(domains) {
|
||||
return false
|
||||
}
|
||||
|
||||
sort.Slice(d, func(i, j int) bool {
|
||||
return d[i] < d[j]
|
||||
})
|
||||
|
||||
sort.Slice(domains, func(i, j int) bool {
|
||||
return domains[i] < domains[j]
|
||||
})
|
||||
|
||||
for i, domain := range d {
|
||||
if domain != domains[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// FromStringList creates a DomainList from a slice of string.
|
||||
func FromStringList(s []string) (List, error) {
|
||||
var dl List
|
||||
for _, domain := range s {
|
||||
d, err := FromString(domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl = append(dl, d)
|
||||
}
|
||||
return dl, nil
|
||||
}
|
||||
|
||||
// FromPunycodeList creates a List from a slice of Punycode-encoded domain strings.
|
||||
func FromPunycodeList(s []string) List {
|
||||
var dl List
|
||||
for _, domain := range s {
|
||||
dl = append(dl, Domain(strings.ToLower(domain)))
|
||||
}
|
||||
return dl
|
||||
}
|
||||
49
shared/management/domain/list_test.go
Normal file
49
shared/management/domain/list_test.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_EqualReturnsTrueForIdenticalLists(t *testing.T) {
|
||||
list1 := List{"domain1", "domain2", "domain3"}
|
||||
list2 := List{"domain1", "domain2", "domain3"}
|
||||
|
||||
assert.True(t, list1.Equal(list2))
|
||||
}
|
||||
|
||||
func Test_EqualReturnsFalseForDifferentLengths(t *testing.T) {
|
||||
list1 := List{"domain1", "domain2"}
|
||||
list2 := List{"domain1", "domain2", "domain3"}
|
||||
|
||||
assert.False(t, list1.Equal(list2))
|
||||
}
|
||||
|
||||
func Test_EqualReturnsFalseForDifferentElements(t *testing.T) {
|
||||
list1 := List{"domain1", "domain2", "domain3"}
|
||||
list2 := List{"domain1", "domain4", "domain3"}
|
||||
|
||||
assert.False(t, list1.Equal(list2))
|
||||
}
|
||||
|
||||
func Test_EqualReturnsTrueForUnsortedIdenticalLists(t *testing.T) {
|
||||
list1 := List{"domain3", "domain1", "domain2"}
|
||||
list2 := List{"domain1", "domain2", "domain3"}
|
||||
|
||||
assert.True(t, list1.Equal(list2))
|
||||
}
|
||||
|
||||
func Test_EqualReturnsFalseForEmptyAndNonEmptyList(t *testing.T) {
|
||||
list1 := List{}
|
||||
list2 := List{"domain1"}
|
||||
|
||||
assert.False(t, list1.Equal(list2))
|
||||
}
|
||||
|
||||
func Test_EqualReturnsTrueForBothEmptyLists(t *testing.T) {
|
||||
list1 := List{}
|
||||
list2 := List{}
|
||||
|
||||
assert.True(t, list1.Equal(list2))
|
||||
}
|
||||
56
shared/management/domain/validate.go
Normal file
56
shared/management/domain/validate.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxDomains = 32
|
||||
|
||||
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)
|
||||
|
||||
// ValidateDomains checks if each domain in the list is valid and returns a punycode-encoded DomainList.
|
||||
func ValidateDomains(domains []string) (List, error) {
|
||||
if len(domains) == 0 {
|
||||
return nil, fmt.Errorf("domains list is empty")
|
||||
}
|
||||
if len(domains) > maxDomains {
|
||||
return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||
}
|
||||
|
||||
var domainList List
|
||||
|
||||
for _, d := range domains {
|
||||
// handles length and idna conversion
|
||||
punycode, err := FromString(d)
|
||||
if err != nil {
|
||||
return domainList, fmt.Errorf("convert domain to punycode: %s: %w", d, err)
|
||||
}
|
||||
|
||||
if !domainRegex.MatchString(string(punycode)) {
|
||||
return domainList, fmt.Errorf("invalid domain format: %s", d)
|
||||
}
|
||||
|
||||
domainList = append(domainList, punycode)
|
||||
}
|
||||
return domainList, nil
|
||||
}
|
||||
|
||||
// ValidateDomainsList checks if each domain in the list is valid
|
||||
func ValidateDomainsList(domains []string) error {
|
||||
if len(domains) == 0 {
|
||||
return nil
|
||||
}
|
||||
if len(domains) > maxDomains {
|
||||
return fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains)
|
||||
}
|
||||
|
||||
for _, d := range domains {
|
||||
d := strings.ToLower(d)
|
||||
if !domainRegex.MatchString(d) {
|
||||
return fmt.Errorf("invalid domain format: %s", d)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
185
shared/management/domain/validate_test.go
Normal file
185
shared/management/domain/validate_test.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidateDomains(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
expected List
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty list",
|
||||
domains: nil,
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Valid ASCII domain",
|
||||
domains: []string{"sub.ex-ample.com"},
|
||||
expected: List{"sub.ex-ample.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Unicode domain",
|
||||
domains: []string{"münchen.de"},
|
||||
expected: List{"xn--mnchen-3ya.de"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Valid Unicode, all labels",
|
||||
domains: []string{"中国.中国.中国"},
|
||||
expected: List{"xn--fiqs8s.xn--fiqs8s.xn--fiqs8s"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "With underscores",
|
||||
domains: []string{"_jabber._tcp.gmail.com"},
|
||||
expected: List{"_jabber._tcp.gmail.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain format",
|
||||
domains: []string{"-example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain format 2",
|
||||
domains: []string{"example.com-"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple domains valid and invalid",
|
||||
domains: []string{"google.com", "invalid,nbdomain.com", "münchen.de"},
|
||||
expected: List{"google.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Valid wildcard domain",
|
||||
domains: []string{"*.example.com"},
|
||||
expected: List{"*.example.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Wildcard with dot domain",
|
||||
domains: []string{".*.example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Wildcard with dot domain",
|
||||
domains: []string{".*.example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid wildcard domain",
|
||||
domains: []string{"a.*.example.com"},
|
||||
expected: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ValidateDomains(tt.domains)
|
||||
assert.Equal(t, tt.wantErr, err != nil)
|
||||
assert.Equal(t, got, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDomainsList(t *testing.T) {
|
||||
validDomains := make([]string, maxDomains)
|
||||
for i := range maxDomains {
|
||||
validDomains[i] = fmt.Sprintf("example%d.com", i)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Empty list",
|
||||
domains: nil,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Single valid ASCII domain",
|
||||
domains: []string{"sub.ex-ample.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Underscores in labels",
|
||||
domains: []string{"_jabber._tcp.gmail.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
// Unlike ValidateDomains (which converts to punycode),
|
||||
// ValidateDomainsStrSlice will fail on non-ASCII domain chars.
|
||||
name: "Unicode domain fails (no punycode conversion)",
|
||||
domains: []string{"münchen.de"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain format - leading dash",
|
||||
domains: []string{"-example.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid domain format - trailing dash",
|
||||
domains: []string{"example-.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Multiple domains with a valid one, then invalid",
|
||||
domains: []string{"google.com", "invalid_domain.com-"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Valid wildcard domain",
|
||||
domains: []string{"*.example.com"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Wildcard with leading dot - invalid",
|
||||
domains: []string{".*.example.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Invalid wildcard with multiple asterisks",
|
||||
domains: []string{"a.*.example.com"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "Exactly maxDomains items (valid)",
|
||||
domains: validDomains,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Exceeds maxDomains items",
|
||||
domains: append(validDomains, "extra.com"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateDomainsList(tt.domains)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
17
shared/management/proto/generate.sh
Executable file
17
shared/management/proto/generate.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if ! which realpath > /dev/null 2>&1
|
||||
then
|
||||
echo realpath is not installed
|
||||
echo run: brew install coreutils
|
||||
exit 1
|
||||
fi
|
||||
|
||||
old_pwd=$(pwd)
|
||||
script_path=$(dirname $(realpath "$0"))
|
||||
cd "$script_path"
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||
protoc -I ./ ./management.proto --go_out=../ --go-grpc_out=../
|
||||
cd "$old_pwd"
|
||||
2
shared/management/proto/go.sum
Normal file
2
shared/management/proto/go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
4519
shared/management/proto/management.pb.go
Normal file
4519
shared/management/proto/management.pb.go
Normal file
File diff suppressed because it is too large
Load Diff
542
shared/management/proto/management.proto
Normal file
542
shared/management/proto/management.proto
Normal file
@@ -0,0 +1,542 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/duration.proto";
|
||||
|
||||
option go_package = "/proto";
|
||||
|
||||
package management;
|
||||
|
||||
service ManagementService {
|
||||
|
||||
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
|
||||
// Returns encrypted LoginResponse in EncryptedMessage.Body
|
||||
rpc Login(EncryptedMessage) returns (EncryptedMessage) {}
|
||||
|
||||
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
|
||||
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
|
||||
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
|
||||
// Returns encrypted SyncResponse in EncryptedMessage.Body
|
||||
rpc Sync(EncryptedMessage) returns (stream EncryptedMessage) {}
|
||||
|
||||
// Exposes a Wireguard public key of the Management service.
|
||||
// This key is used to support message encryption between client and server
|
||||
rpc GetServerKey(Empty) returns (ServerKeyResponse) {}
|
||||
|
||||
// health check endpoint
|
||||
rpc isHealthy(Empty) returns (Empty) {}
|
||||
|
||||
// Exposes a device authorization flow information
|
||||
// This is used for initiating a Oauth 2 device authorization grant flow
|
||||
// which will be used by our clients to Login.
|
||||
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
|
||||
rpc GetDeviceAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
|
||||
|
||||
// Exposes a PKCE authorization code flow information
|
||||
// This is used for initiating a Oauth 2 authorization grant flow
|
||||
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
|
||||
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
|
||||
rpc GetPKCEAuthorizationFlow(EncryptedMessage) returns (EncryptedMessage) {}
|
||||
|
||||
// SyncMeta is used to sync metadata of the peer.
|
||||
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
|
||||
// sync meta will evaluate the checks and update the peer meta with the result.
|
||||
// EncryptedMessage of the request has a body of Empty.
|
||||
rpc SyncMeta(EncryptedMessage) returns (Empty) {}
|
||||
|
||||
// Logout logs out the peer and removes it from the management server
|
||||
rpc Logout(EncryptedMessage) returns (Empty) {}
|
||||
}
|
||||
|
||||
message EncryptedMessage {
|
||||
// Wireguard public key
|
||||
string wgPubKey = 1;
|
||||
|
||||
// encrypted message Body
|
||||
bytes body = 2;
|
||||
// Version of the Netbird Management Service protocol
|
||||
int32 version = 3;
|
||||
}
|
||||
|
||||
message SyncRequest {
|
||||
// Meta data of the peer
|
||||
PeerSystemMeta meta = 1;
|
||||
}
|
||||
|
||||
// SyncResponse represents a state that should be applied to the local peer (e.g. Netbird servers config as well as local peer and remote peers configs)
|
||||
message SyncResponse {
|
||||
|
||||
// Global config
|
||||
NetbirdConfig netbirdConfig = 1;
|
||||
|
||||
// Deprecated. Use NetworkMap.PeerConfig
|
||||
PeerConfig peerConfig = 2;
|
||||
|
||||
// Deprecated. Use NetworkMap.RemotePeerConfig
|
||||
repeated RemotePeerConfig remotePeers = 3;
|
||||
|
||||
// Indicates whether remotePeers array is empty or not to bypass protobuf null and empty array equality.
|
||||
// Deprecated. Use NetworkMap.remotePeersIsEmpty
|
||||
bool remotePeersIsEmpty = 4;
|
||||
|
||||
NetworkMap NetworkMap = 5;
|
||||
|
||||
// Posture checks to be evaluated by client
|
||||
repeated Checks Checks = 6;
|
||||
}
|
||||
|
||||
message SyncMetaRequest {
|
||||
// Meta data of the peer
|
||||
PeerSystemMeta meta = 1;
|
||||
}
|
||||
|
||||
message LoginRequest {
|
||||
// Pre-authorized setup key (can be empty)
|
||||
string setupKey = 1;
|
||||
// Meta data of the peer (e.g. name, os_name, os_version,
|
||||
PeerSystemMeta meta = 2;
|
||||
// SSO token (can be empty)
|
||||
string jwtToken = 3;
|
||||
// Can be absent for now.
|
||||
PeerKeys peerKeys = 4;
|
||||
|
||||
repeated string dnsLabels = 5;
|
||||
}
|
||||
|
||||
// PeerKeys is additional peer info like SSH pub key and WireGuard public key.
|
||||
// This message is sent on Login or register requests, or when a key rotation has to happen.
|
||||
message PeerKeys {
|
||||
|
||||
// sshPubKey represents a public SSH key of the peer. Can be absent.
|
||||
bytes sshPubKey = 1;
|
||||
// wgPubKey represents a public WireGuard key of the peer. Can be absent.
|
||||
bytes wgPubKey = 2;
|
||||
}
|
||||
|
||||
// Environment is part of the PeerSystemMeta and describes the environment the agent is running in.
|
||||
message Environment {
|
||||
// cloud is the cloud provider the agent is running in if applicable.
|
||||
string cloud = 1;
|
||||
// platform is the platform the agent is running on if applicable.
|
||||
string platform = 2;
|
||||
}
|
||||
|
||||
// File represents a file on the system.
|
||||
message File {
|
||||
// path is the path to the file.
|
||||
string path = 1;
|
||||
// exist indicate whether the file exists.
|
||||
bool exist = 2;
|
||||
// processIsRunning indicates whether the file is a running process or not.
|
||||
bool processIsRunning = 3;
|
||||
}
|
||||
|
||||
message Flags {
|
||||
bool rosenpassEnabled = 1;
|
||||
bool rosenpassPermissive = 2;
|
||||
bool serverSSHAllowed = 3;
|
||||
|
||||
bool disableClientRoutes = 4;
|
||||
bool disableServerRoutes = 5;
|
||||
bool disableDNS = 6;
|
||||
bool disableFirewall = 7;
|
||||
bool blockLANAccess = 8;
|
||||
bool blockInbound = 9;
|
||||
|
||||
bool lazyConnectionEnabled = 10;
|
||||
}
|
||||
|
||||
// PeerSystemMeta is machine meta data like OS and version.
|
||||
message PeerSystemMeta {
|
||||
string hostname = 1;
|
||||
string goOS = 2;
|
||||
string kernel = 3;
|
||||
string core = 4;
|
||||
string platform = 5;
|
||||
string OS = 6;
|
||||
string netbirdVersion = 7;
|
||||
string uiVersion = 8;
|
||||
string kernelVersion = 9;
|
||||
string OSVersion = 10;
|
||||
repeated NetworkAddress networkAddresses = 11;
|
||||
string sysSerialNumber = 12;
|
||||
string sysProductName = 13;
|
||||
string sysManufacturer = 14;
|
||||
Environment environment = 15;
|
||||
repeated File files = 16;
|
||||
Flags flags = 17;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
// Global config
|
||||
NetbirdConfig netbirdConfig = 1;
|
||||
// Peer local config
|
||||
PeerConfig peerConfig = 2;
|
||||
// Posture checks to be evaluated by client
|
||||
repeated Checks Checks = 3;
|
||||
}
|
||||
|
||||
message ServerKeyResponse {
|
||||
// Server's Wireguard public key
|
||||
string key = 1;
|
||||
// Key expiration timestamp after which the key should be fetched again by the client
|
||||
google.protobuf.Timestamp expiresAt = 2;
|
||||
// Version of the Netbird Management Service protocol
|
||||
int32 version = 3;
|
||||
}
|
||||
|
||||
message Empty {}
|
||||
|
||||
// NetbirdConfig is a common configuration of any Netbird peer. It contains STUN, TURN, Signal and Management servers configurations
|
||||
message NetbirdConfig {
|
||||
// a list of STUN servers
|
||||
repeated HostConfig stuns = 1;
|
||||
// a list of TURN servers
|
||||
repeated ProtectedHostConfig turns = 2;
|
||||
|
||||
// a Signal server config
|
||||
HostConfig signal = 3;
|
||||
|
||||
RelayConfig relay = 4;
|
||||
|
||||
FlowConfig flow = 5;
|
||||
}
|
||||
|
||||
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
||||
message HostConfig {
|
||||
// URI of the resource e.g. turns://stun.netbird.io:4430 or signal.netbird.io:10000
|
||||
string uri = 1;
|
||||
Protocol protocol = 2;
|
||||
|
||||
enum Protocol {
|
||||
UDP = 0;
|
||||
TCP = 1;
|
||||
HTTP = 2;
|
||||
HTTPS = 3;
|
||||
DTLS = 4;
|
||||
}
|
||||
}
|
||||
|
||||
message RelayConfig {
|
||||
repeated string urls = 1;
|
||||
string tokenPayload = 2;
|
||||
string tokenSignature = 3;
|
||||
}
|
||||
|
||||
message FlowConfig {
|
||||
string url = 1;
|
||||
string tokenPayload = 2;
|
||||
string tokenSignature = 3;
|
||||
google.protobuf.Duration interval = 4;
|
||||
bool enabled = 5;
|
||||
|
||||
// counters determines if flow packets and bytes counters should be sent
|
||||
bool counters = 6;
|
||||
// exitNodeCollection determines if event collection on exit nodes should be enabled
|
||||
bool exitNodeCollection = 7;
|
||||
// dnsCollection determines if DNS event collection should be enabled
|
||||
bool dnsCollection = 8;
|
||||
}
|
||||
|
||||
// ProtectedHostConfig is similar to HostConfig but has additional user and password
|
||||
// Mostly used for TURN servers
|
||||
message ProtectedHostConfig {
|
||||
HostConfig hostConfig = 1;
|
||||
string user = 2;
|
||||
string password = 3;
|
||||
}
|
||||
|
||||
// PeerConfig represents a configuration of a "our" peer.
|
||||
// The properties are used to configure local Wireguard
|
||||
message PeerConfig {
|
||||
// Peer's virtual IP address within the Netbird VPN (a Wireguard address config)
|
||||
string address = 1;
|
||||
// Netbird DNS server (a Wireguard DNS config)
|
||||
string dns = 2;
|
||||
|
||||
// SSHConfig of the peer.
|
||||
SSHConfig sshConfig = 3;
|
||||
// Peer fully qualified domain name
|
||||
string fqdn = 4;
|
||||
|
||||
bool RoutingPeerDnsResolutionEnabled = 5;
|
||||
|
||||
bool LazyConnectionEnabled = 6;
|
||||
}
|
||||
|
||||
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
|
||||
message NetworkMap {
|
||||
// Serial is an ID of the network state to be used by clients to order updates.
|
||||
// The larger the Serial the newer the configuration.
|
||||
// E.g. the client app should keep track of this id locally and discard all the configurations with a lower value
|
||||
uint64 Serial = 1;
|
||||
|
||||
// PeerConfig represents configuration of a peer
|
||||
PeerConfig peerConfig = 2;
|
||||
|
||||
// RemotePeerConfig represents a list of remote peers that the receiver can connect to
|
||||
repeated RemotePeerConfig remotePeers = 3;
|
||||
|
||||
// Indicates whether remotePeers array is empty or not to bypass protobuf null and empty array equality.
|
||||
bool remotePeersIsEmpty = 4;
|
||||
|
||||
// List of routes to be applied
|
||||
repeated Route Routes = 5;
|
||||
|
||||
// DNS config to be applied
|
||||
DNSConfig DNSConfig = 6;
|
||||
|
||||
// RemotePeerConfig represents a list of remote peers that the receiver can connect to
|
||||
repeated RemotePeerConfig offlinePeers = 7;
|
||||
|
||||
// FirewallRule represents a list of firewall rules to be applied to peer
|
||||
repeated FirewallRule FirewallRules = 8;
|
||||
|
||||
// firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality.
|
||||
bool firewallRulesIsEmpty = 9;
|
||||
|
||||
// RoutesFirewallRules represents a list of routes firewall rules to be applied to peer
|
||||
repeated RouteFirewallRule routesFirewallRules = 10;
|
||||
|
||||
// RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality.
|
||||
bool routesFirewallRulesIsEmpty = 11;
|
||||
|
||||
repeated ForwardingRule forwardingRules = 12;
|
||||
}
|
||||
|
||||
// RemotePeerConfig represents a configuration of a remote peer.
|
||||
// The properties are used to configure WireGuard Peers sections
|
||||
message RemotePeerConfig {
|
||||
|
||||
// A WireGuard public key of a remote peer
|
||||
string wgPubKey = 1;
|
||||
|
||||
// WireGuard allowed IPs of a remote peer e.g. [10.30.30.1/32]
|
||||
repeated string allowedIps = 2;
|
||||
|
||||
// SSHConfig is a SSH config of the remote peer. SSHConfig.sshPubKey should be ignored because peer knows it's SSH key.
|
||||
SSHConfig sshConfig = 3;
|
||||
|
||||
// Peer fully qualified domain name
|
||||
string fqdn = 4;
|
||||
|
||||
string agentVersion = 5;
|
||||
}
|
||||
|
||||
// SSHConfig represents SSH configurations of a peer.
|
||||
message SSHConfig {
|
||||
// sshEnabled indicates whether a SSH server is enabled on this peer
|
||||
bool sshEnabled = 1;
|
||||
|
||||
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
|
||||
// This property should be ignore if SSHConfig comes from PeerConfig.
|
||||
bytes sshPubKey = 2;
|
||||
}
|
||||
|
||||
// DeviceAuthorizationFlowRequest empty struct for future expansion
|
||||
message DeviceAuthorizationFlowRequest {}
|
||||
// DeviceAuthorizationFlow represents Device Authorization Flow information
|
||||
// that can be used by the client to login initiate a Oauth 2.0 device authorization grant flow
|
||||
// see https://datatracker.ietf.org/doc/html/rfc8628
|
||||
message DeviceAuthorizationFlow {
|
||||
// An IDP provider , (eg. Auth0)
|
||||
provider Provider = 1;
|
||||
ProviderConfig ProviderConfig = 2;
|
||||
|
||||
enum provider {
|
||||
HOSTED = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// PKCEAuthorizationFlowRequest empty struct for future expansion
|
||||
message PKCEAuthorizationFlowRequest {}
|
||||
|
||||
// PKCEAuthorizationFlow represents Authorization Code Flow information
|
||||
// that can be used by the client to login initiate a Oauth 2.0 authorization code grant flow
|
||||
// with Proof Key for Code Exchange (PKCE). See https://datatracker.ietf.org/doc/html/rfc7636
|
||||
message PKCEAuthorizationFlow {
|
||||
ProviderConfig ProviderConfig = 1;
|
||||
}
|
||||
|
||||
// ProviderConfig has all attributes needed to initiate a device/pkce authorization flow
|
||||
message ProviderConfig {
|
||||
// An IDP application client id
|
||||
string ClientID = 1;
|
||||
// An IDP application client secret
|
||||
string ClientSecret = 2;
|
||||
// An IDP API domain
|
||||
// Deprecated. Use a DeviceAuthEndpoint and TokenEndpoint
|
||||
string Domain = 3;
|
||||
// An Audience for validation
|
||||
string Audience = 4;
|
||||
// DeviceAuthEndpoint is an endpoint to request device authentication code.
|
||||
string DeviceAuthEndpoint = 5;
|
||||
// TokenEndpoint is an endpoint to request auth token.
|
||||
string TokenEndpoint = 6;
|
||||
// Scopes provides the scopes to be included in the token request
|
||||
string Scope = 7;
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
bool UseIDToken = 8;
|
||||
// AuthorizationEndpoint is the endpoint of an IDP manager where clients can obtain authorization code.
|
||||
string AuthorizationEndpoint = 9;
|
||||
// RedirectURLs handles authorization code from IDP manager
|
||||
repeated string RedirectURLs = 10;
|
||||
// DisablePromptLogin makes the PKCE flow to not prompt the user for login
|
||||
bool DisablePromptLogin = 11;
|
||||
// LoginFlags sets the PKCE flow login details
|
||||
uint32 LoginFlag = 12;
|
||||
}
|
||||
|
||||
// Route represents a route.Route object
|
||||
message Route {
|
||||
string ID = 1;
|
||||
string Network = 2;
|
||||
int64 NetworkType = 3;
|
||||
string Peer = 4;
|
||||
int64 Metric = 5;
|
||||
bool Masquerade = 6;
|
||||
string NetID = 7;
|
||||
repeated string Domains = 8;
|
||||
bool keepRoute = 9;
|
||||
}
|
||||
|
||||
// DNSConfig represents a dns.Update
|
||||
message DNSConfig {
|
||||
bool ServiceEnable = 1;
|
||||
repeated NameServerGroup NameServerGroups = 2;
|
||||
repeated CustomZone CustomZones = 3;
|
||||
}
|
||||
|
||||
// CustomZone represents a dns.CustomZone
|
||||
message CustomZone {
|
||||
string Domain = 1;
|
||||
repeated SimpleRecord Records = 2;
|
||||
}
|
||||
|
||||
// SimpleRecord represents a dns.SimpleRecord
|
||||
message SimpleRecord {
|
||||
string Name = 1;
|
||||
int64 Type = 2;
|
||||
string Class = 3;
|
||||
int64 TTL = 4;
|
||||
string RData = 5;
|
||||
}
|
||||
|
||||
// NameServerGroup represents a dns.NameServerGroup
|
||||
message NameServerGroup {
|
||||
repeated NameServer NameServers = 1;
|
||||
bool Primary = 2;
|
||||
repeated string Domains = 3;
|
||||
bool SearchDomainsEnabled = 4;
|
||||
}
|
||||
|
||||
// NameServer represents a dns.NameServer
|
||||
message NameServer {
|
||||
string IP = 1;
|
||||
int64 NSType = 2;
|
||||
int64 Port = 3;
|
||||
}
|
||||
|
||||
enum RuleProtocol {
|
||||
UNKNOWN = 0;
|
||||
ALL = 1;
|
||||
TCP = 2;
|
||||
UDP = 3;
|
||||
ICMP = 4;
|
||||
CUSTOM = 5;
|
||||
}
|
||||
|
||||
enum RuleDirection {
|
||||
IN = 0;
|
||||
OUT = 1;
|
||||
}
|
||||
|
||||
enum RuleAction {
|
||||
ACCEPT = 0;
|
||||
DROP = 1;
|
||||
}
|
||||
|
||||
|
||||
// FirewallRule represents a firewall rule
|
||||
message FirewallRule {
|
||||
string PeerIP = 1;
|
||||
RuleDirection Direction = 2;
|
||||
RuleAction Action = 3;
|
||||
RuleProtocol Protocol = 4;
|
||||
string Port = 5;
|
||||
PortInfo PortInfo = 6;
|
||||
|
||||
// PolicyID is the ID of the policy that this rule belongs to
|
||||
bytes PolicyID = 7;
|
||||
}
|
||||
|
||||
message NetworkAddress {
|
||||
string netIP = 1;
|
||||
string mac = 2;
|
||||
}
|
||||
|
||||
message Checks {
|
||||
repeated string Files = 1;
|
||||
}
|
||||
|
||||
|
||||
message PortInfo {
|
||||
oneof portSelection {
|
||||
uint32 port = 1;
|
||||
Range range = 2;
|
||||
}
|
||||
|
||||
message Range {
|
||||
uint32 start = 1;
|
||||
uint32 end = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// RouteFirewallRule signifies a firewall rule applicable for a routed network.
|
||||
message RouteFirewallRule {
|
||||
// sourceRanges IP ranges of the routing peers.
|
||||
repeated string sourceRanges = 1;
|
||||
|
||||
// Action to be taken by the firewall when the rule is applicable.
|
||||
RuleAction action = 2;
|
||||
|
||||
// Network prefix for the routed network.
|
||||
string destination = 3;
|
||||
|
||||
// Protocol of the routed network.
|
||||
RuleProtocol protocol = 4;
|
||||
|
||||
// Details about the port.
|
||||
PortInfo portInfo = 5;
|
||||
|
||||
// IsDynamic indicates if the route is a DNS route.
|
||||
bool isDynamic = 6;
|
||||
|
||||
// Domains is a list of domains for which the rule is applicable.
|
||||
repeated string domains = 7;
|
||||
|
||||
// CustomProtocol is a custom protocol ID.
|
||||
uint32 customProtocol = 8;
|
||||
|
||||
// PolicyID is the ID of the policy that this rule belongs to
|
||||
bytes PolicyID = 9;
|
||||
|
||||
// RouteID is the ID of the route that this rule belongs to
|
||||
string RouteID = 10;
|
||||
}
|
||||
|
||||
message ForwardingRule {
|
||||
// Protocol of the forwarding rule
|
||||
RuleProtocol protocol = 1;
|
||||
|
||||
// portInfo is the ingress destination port information, where the traffic arrives in the gateway node
|
||||
PortInfo destinationPort = 2;
|
||||
|
||||
// IP address of the translated address (remote peer) to send traffic to
|
||||
bytes translatedAddress = 3;
|
||||
|
||||
// Translated port information, where the traffic should be forwarded to
|
||||
PortInfo translatedPort = 4;
|
||||
}
|
||||
429
shared/management/proto/management_grpc.pb.go
Normal file
429
shared/management/proto/management_grpc.pb.go
Normal file
@@ -0,0 +1,429 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// ManagementServiceClient is the client API for ManagementService service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type ManagementServiceClient interface {
|
||||
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
|
||||
// Returns encrypted LoginResponse in EncryptedMessage.Body
|
||||
Login(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
|
||||
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
|
||||
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
|
||||
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
|
||||
// Returns encrypted SyncResponse in EncryptedMessage.Body
|
||||
Sync(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (ManagementService_SyncClient, error)
|
||||
// Exposes a Wireguard public key of the Management service.
|
||||
// This key is used to support message encryption between client and server
|
||||
GetServerKey(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*ServerKeyResponse, error)
|
||||
// health check endpoint
|
||||
IsHealthy(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error)
|
||||
// Exposes a device authorization flow information
|
||||
// This is used for initiating a Oauth 2 device authorization grant flow
|
||||
// which will be used by our clients to Login.
|
||||
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
|
||||
GetDeviceAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
|
||||
// Exposes a PKCE authorization code flow information
|
||||
// This is used for initiating a Oauth 2 authorization grant flow
|
||||
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
|
||||
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
|
||||
GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
|
||||
// SyncMeta is used to sync metadata of the peer.
|
||||
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
|
||||
// sync meta will evaluate the checks and update the peer meta with the result.
|
||||
// EncryptedMessage of the request has a body of Empty.
|
||||
SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
|
||||
// Logout logs out the peer and removes it from the management server
|
||||
Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
|
||||
}
|
||||
|
||||
type managementServiceClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewManagementServiceClient(cc grpc.ClientConnInterface) ManagementServiceClient {
|
||||
return &managementServiceClient{cc}
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) Login(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
|
||||
out := new(EncryptedMessage)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/Login", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) Sync(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (ManagementService_SyncClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &ManagementService_ServiceDesc.Streams[0], "/management.ManagementService/Sync", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &managementServiceSyncClient{stream}
|
||||
if err := x.ClientStream.SendMsg(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := x.ClientStream.CloseSend(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type ManagementService_SyncClient interface {
|
||||
Recv() (*EncryptedMessage, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type managementServiceSyncClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *managementServiceSyncClient) Recv() (*EncryptedMessage, error) {
|
||||
m := new(EncryptedMessage)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) GetServerKey(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*ServerKeyResponse, error) {
|
||||
out := new(ServerKeyResponse)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/GetServerKey", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) IsHealthy(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/isHealthy", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) GetDeviceAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
|
||||
out := new(EncryptedMessage)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/GetDeviceAuthorizationFlow", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) GetPKCEAuthorizationFlow(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
|
||||
out := new(EncryptedMessage)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/GetPKCEAuthorizationFlow", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) SyncMeta(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/SyncMeta", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error) {
|
||||
out := new(Empty)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/Logout", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ManagementServiceServer is the server API for ManagementService service.
|
||||
// All implementations must embed UnimplementedManagementServiceServer
|
||||
// for forward compatibility
|
||||
type ManagementServiceServer interface {
|
||||
// Login logs in peer. In case server returns codes.PermissionDenied this endpoint can be used to register Peer providing LoginRequest.setupKey
|
||||
// Returns encrypted LoginResponse in EncryptedMessage.Body
|
||||
Login(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
|
||||
// Sync enables peer synchronization. Each peer that is connected to this stream will receive updates from the server.
|
||||
// For example, if a new peer has been added to an account all other connected peers will receive this peer's Wireguard public key as an update
|
||||
// The initial SyncResponse contains all of the available peers so the local state can be refreshed
|
||||
// Returns encrypted SyncResponse in EncryptedMessage.Body
|
||||
Sync(*EncryptedMessage, ManagementService_SyncServer) error
|
||||
// Exposes a Wireguard public key of the Management service.
|
||||
// This key is used to support message encryption between client and server
|
||||
GetServerKey(context.Context, *Empty) (*ServerKeyResponse, error)
|
||||
// health check endpoint
|
||||
IsHealthy(context.Context, *Empty) (*Empty, error)
|
||||
// Exposes a device authorization flow information
|
||||
// This is used for initiating a Oauth 2 device authorization grant flow
|
||||
// which will be used by our clients to Login.
|
||||
// EncryptedMessage of the request has a body of DeviceAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of DeviceAuthorizationFlow.
|
||||
GetDeviceAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
|
||||
// Exposes a PKCE authorization code flow information
|
||||
// This is used for initiating a Oauth 2 authorization grant flow
|
||||
// with Proof Key for Code Exchange (PKCE) which will be used by our clients to Login.
|
||||
// EncryptedMessage of the request has a body of PKCEAuthorizationFlowRequest.
|
||||
// EncryptedMessage of the response has a body of PKCEAuthorizationFlow.
|
||||
GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
|
||||
// SyncMeta is used to sync metadata of the peer.
|
||||
// After sync the peer if there is a change in peer posture check which needs to be evaluated by the client,
|
||||
// sync meta will evaluate the checks and update the peer meta with the result.
|
||||
// EncryptedMessage of the request has a body of Empty.
|
||||
SyncMeta(context.Context, *EncryptedMessage) (*Empty, error)
|
||||
// Logout logs out the peer and removes it from the management server
|
||||
Logout(context.Context, *EncryptedMessage) (*Empty, error)
|
||||
mustEmbedUnimplementedManagementServiceServer()
|
||||
}
|
||||
|
||||
// UnimplementedManagementServiceServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedManagementServiceServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedManagementServiceServer) Login(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Login not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) Sync(*EncryptedMessage, ManagementService_SyncServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Sync not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) GetServerKey(context.Context, *Empty) (*ServerKeyResponse, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetServerKey not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) IsHealthy(context.Context, *Empty) (*Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method IsHealthy not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) GetDeviceAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetDeviceAuthorizationFlow not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) GetPKCEAuthorizationFlow(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetPKCEAuthorizationFlow not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) SyncMeta(context.Context, *EncryptedMessage) (*Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method SyncMeta not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMessage) (*Empty, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Logout not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) mustEmbedUnimplementedManagementServiceServer() {}
|
||||
|
||||
// UnsafeManagementServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to ManagementServiceServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeManagementServiceServer interface {
|
||||
mustEmbedUnimplementedManagementServiceServer()
|
||||
}
|
||||
|
||||
func RegisterManagementServiceServer(s grpc.ServiceRegistrar, srv ManagementServiceServer) {
|
||||
s.RegisterService(&ManagementService_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _ManagementService_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).Login(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/Login",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).Login(ctx, req.(*EncryptedMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagementService_Sync_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
m := new(EncryptedMessage)
|
||||
if err := stream.RecvMsg(m); err != nil {
|
||||
return err
|
||||
}
|
||||
return srv.(ManagementServiceServer).Sync(m, &managementServiceSyncServer{stream})
|
||||
}
|
||||
|
||||
type ManagementService_SyncServer interface {
|
||||
Send(*EncryptedMessage) error
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type managementServiceSyncServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *managementServiceSyncServer) Send(m *EncryptedMessage) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func _ManagementService_GetServerKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).GetServerKey(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/GetServerKey",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).GetServerKey(ctx, req.(*Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagementService_IsHealthy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(Empty)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).IsHealthy(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/isHealthy",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).IsHealthy(ctx, req.(*Empty))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagementService_GetDeviceAuthorizationFlow_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).GetDeviceAuthorizationFlow(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/GetDeviceAuthorizationFlow",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).GetDeviceAuthorizationFlow(ctx, req.(*EncryptedMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagementService_GetPKCEAuthorizationFlow_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).GetPKCEAuthorizationFlow(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/GetPKCEAuthorizationFlow",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).GetPKCEAuthorizationFlow(ctx, req.(*EncryptedMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagementService_SyncMeta_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).SyncMeta(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/SyncMeta",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).SyncMeta(ctx, req.(*EncryptedMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagementService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).Logout(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/Logout",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).Logout(ctx, req.(*EncryptedMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
// ManagementService_ServiceDesc is the grpc.ServiceDesc for ManagementService service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var ManagementService_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "management.ManagementService",
|
||||
HandlerType: (*ManagementServiceServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Login",
|
||||
Handler: _ManagementService_Login_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetServerKey",
|
||||
Handler: _ManagementService_GetServerKey_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "isHealthy",
|
||||
Handler: _ManagementService_IsHealthy_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetDeviceAuthorizationFlow",
|
||||
Handler: _ManagementService_GetDeviceAuthorizationFlow_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "GetPKCEAuthorizationFlow",
|
||||
Handler: _ManagementService_GetPKCEAuthorizationFlow_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "SyncMeta",
|
||||
Handler: _ManagementService_SyncMeta_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Logout",
|
||||
Handler: _ManagementService_Logout_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "Sync",
|
||||
Handler: _ManagementService_Sync_Handler,
|
||||
ServerStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "management.proto",
|
||||
}
|
||||
14
shared/relay/auth/allow/allow_all.go
Normal file
14
shared/relay/auth/allow/allow_all.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package allow
|
||||
|
||||
// Auth is a Validator that allows all connections.
|
||||
// Used this for testing purposes only.
|
||||
type Auth struct {
|
||||
}
|
||||
|
||||
func (a *Auth) Validate(any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Auth) ValidateHelloMsgType(any) error {
|
||||
return nil
|
||||
}
|
||||
26
shared/relay/auth/doc.go
Normal file
26
shared/relay/auth/doc.go
Normal file
@@ -0,0 +1,26 @@
|
||||
/*
|
||||
Package auth manages the authentication process with the relay server.
|
||||
|
||||
Key Components:
|
||||
|
||||
Validator: The Validator interface defines the Validate method. Any type that provides this method can be used as a
|
||||
Validator.
|
||||
|
||||
Methods:
|
||||
|
||||
Validate(func() hash.Hash, any): This method is defined in the Validator interface and is used to validate the authentication.
|
||||
|
||||
Usage:
|
||||
|
||||
To create a new AllowAllAuth validator, simply instantiate it:
|
||||
|
||||
validator := &allow.Auth{}
|
||||
|
||||
To validate the authentication, use the Validate method:
|
||||
|
||||
err := validator.Validate(sha256.New, any)
|
||||
|
||||
This package provides a simple and effective way to manage authentication with the relay server, ensuring that the
|
||||
peers are authenticated properly.
|
||||
*/
|
||||
package auth
|
||||
1
shared/relay/auth/go.sum
Normal file
1
shared/relay/auth/go.sum
Normal file
@@ -0,0 +1 @@
|
||||
golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
|
||||
8
shared/relay/auth/hmac/doc.go
Normal file
8
shared/relay/auth/hmac/doc.go
Normal file
@@ -0,0 +1,8 @@
|
||||
/*
|
||||
This package uses a similar HMAC method for authentication with the TURN server. The Management server provides the
|
||||
tokens for the peers. The peers manage these tokens in the token store. The token store is a simple thread safe store
|
||||
that keeps the tokens in memory. These tokens are used to authenticate the peers with the Relay server in the hello
|
||||
message.
|
||||
*/
|
||||
|
||||
package hmac
|
||||
44
shared/relay/auth/hmac/store.go
Normal file
44
shared/relay/auth/hmac/store.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
v2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
|
||||
)
|
||||
|
||||
// TokenStore is a simple in-memory store for token
|
||||
// With this can update the token in thread safe way
|
||||
type TokenStore struct {
|
||||
mu sync.Mutex
|
||||
token []byte
|
||||
}
|
||||
|
||||
func (a *TokenStore) UpdateToken(token *Token) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if token == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sig, err := base64.StdEncoding.DecodeString(token.Signature)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decode signature: %w", err)
|
||||
}
|
||||
|
||||
tok := v2.Token{
|
||||
AuthAlgo: v2.AuthAlgoHMACSHA256,
|
||||
Signature: sig,
|
||||
Payload: []byte(token.Payload),
|
||||
}
|
||||
|
||||
a.token = tok.Marshal()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *TokenStore) TokenBinary() []byte {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.token
|
||||
}
|
||||
94
shared/relay/auth/hmac/token.go
Normal file
94
shared/relay/auth/hmac/token.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/hmac"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"fmt"
|
||||
"hash"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
Payload string
|
||||
Signature string
|
||||
}
|
||||
|
||||
func unmarshalToken(payload []byte) (Token, error) {
|
||||
var creds Token
|
||||
buffer := bytes.NewBuffer(payload)
|
||||
decoder := gob.NewDecoder(buffer)
|
||||
err := decoder.Decode(&creds)
|
||||
return creds, err
|
||||
}
|
||||
|
||||
// TimedHMAC generates a token with TTL and uses a pre-shared secret known to the relay server
|
||||
type TimedHMAC struct {
|
||||
secret string
|
||||
timeToLive time.Duration
|
||||
}
|
||||
|
||||
// NewTimedHMAC creates a new TimedHMAC instance
|
||||
func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
|
||||
return &TimedHMAC{
|
||||
secret: secret,
|
||||
timeToLive: timeToLive,
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC
|
||||
// hash of a timestamp with a preshared TURN secret
|
||||
func (m *TimedHMAC) GenerateToken(algo func() hash.Hash) (*Token, error) {
|
||||
timeAuth := time.Now().Add(m.timeToLive).Unix()
|
||||
timeStamp := strconv.FormatInt(timeAuth, 10)
|
||||
|
||||
checksum, err := m.generate(algo, timeStamp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Token{
|
||||
Payload: timeStamp,
|
||||
Signature: base64.StdEncoding.EncodeToString(checksum),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Validate checks if the token is valid
|
||||
func (m *TimedHMAC) Validate(algo func() hash.Hash, token Token) error {
|
||||
expectedMAC, err := m.generate(algo, token.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
|
||||
|
||||
if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
|
||||
return fmt.Errorf("signature mismatch")
|
||||
}
|
||||
|
||||
timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid payload: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().Unix() > timeAuthInt {
|
||||
return fmt.Errorf("expired token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *TimedHMAC) generate(algo func() hash.Hash, payload string) ([]byte, error) {
|
||||
mac := hmac.New(algo, []byte(m.secret))
|
||||
_, err := mac.Write([]byte(payload))
|
||||
if err != nil {
|
||||
log.Debugf("failed to generate token: %s", err)
|
||||
return nil, fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
|
||||
return mac.Sum(nil), nil
|
||||
}
|
||||
105
shared/relay/auth/hmac/token_test.go
Normal file
105
shared/relay/auth/hmac/token_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateCredentials(t *testing.T) {
|
||||
secret := "secret"
|
||||
timeToLive := 1 * time.Hour
|
||||
v := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := v.GenerateToken(sha1.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if creds.Payload == "" {
|
||||
t.Fatalf("expected non-empty payload")
|
||||
}
|
||||
|
||||
_, err = strconv.ParseInt(creds.Payload, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
|
||||
}
|
||||
|
||||
_, err = base64.StdEncoding.DecodeString(creds.Signature)
|
||||
if err != nil {
|
||||
t.Fatalf("expected signature to be base64 encoded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCredentials(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
manager := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
// Test valid token
|
||||
creds, err := manager.GenerateToken(sha1.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if err := manager.Validate(sha1.New, *creds); err != nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidSignature(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
manager := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := manager.GenerateToken(sha256.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
invalidCreds := &Token{
|
||||
Payload: creds.Payload,
|
||||
Signature: "invalidsignature",
|
||||
}
|
||||
|
||||
if err = manager.Validate(sha1.New, *invalidCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
v := NewTimedHMAC(secret, -1*time.Hour)
|
||||
expiredCreds, err := v.GenerateToken(sha256.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if err = v.Validate(sha1.New, *expiredCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to expiration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidPayload(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
v := NewTimedHMAC(secret, timeToLive)
|
||||
|
||||
creds, err := v.GenerateToken(sha256.New)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
// Test invalid payload
|
||||
invalidPayloadCreds := &Token{
|
||||
Payload: "invalidtimestamp",
|
||||
Signature: creds.Signature,
|
||||
}
|
||||
|
||||
if err = v.Validate(sha1.New, *invalidPayloadCreds); err == nil {
|
||||
t.Fatalf("expected invalid token due to invalid payload")
|
||||
}
|
||||
}
|
||||
40
shared/relay/auth/hmac/v2/algo.go
Normal file
40
shared/relay/auth/hmac/v2/algo.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"hash"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthAlgoUnknown AuthAlgo = iota
|
||||
AuthAlgoHMACSHA256
|
||||
)
|
||||
|
||||
type AuthAlgo uint8
|
||||
|
||||
func (a AuthAlgo) String() string {
|
||||
switch a {
|
||||
case AuthAlgoHMACSHA256:
|
||||
return "HMAC-SHA256"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (a AuthAlgo) New() func() hash.Hash {
|
||||
switch a {
|
||||
case AuthAlgoHMACSHA256:
|
||||
return sha256.New
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a AuthAlgo) Size() int {
|
||||
switch a {
|
||||
case AuthAlgoHMACSHA256:
|
||||
return sha256.Size
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
45
shared/relay/auth/hmac/v2/generator.go
Normal file
45
shared/relay/auth/hmac/v2/generator.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"fmt"
|
||||
"hash"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Generator struct {
|
||||
algo func() hash.Hash
|
||||
algoType AuthAlgo
|
||||
secret []byte
|
||||
timeToLive time.Duration
|
||||
}
|
||||
|
||||
func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
|
||||
algoFunc := algo.New()
|
||||
if algoFunc == nil {
|
||||
return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
|
||||
}
|
||||
return &Generator{
|
||||
algo: algoFunc,
|
||||
algoType: algo,
|
||||
secret: secret,
|
||||
timeToLive: timeToLive,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (g *Generator) GenerateToken() (*Token, error) {
|
||||
expirationTime := time.Now().Add(g.timeToLive).Unix()
|
||||
|
||||
payload := []byte(strconv.FormatInt(expirationTime, 10))
|
||||
|
||||
h := hmac.New(g.algo, g.secret)
|
||||
h.Write(payload)
|
||||
signature := h.Sum(nil)
|
||||
|
||||
return &Token{
|
||||
AuthAlgo: g.algoType,
|
||||
Signature: signature,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
110
shared/relay/auth/hmac/v2/hmac_test.go
Normal file
110
shared/relay/auth/hmac/v2/hmac_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateCredentials(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(token.Payload) == 0 {
|
||||
t.Fatalf("expected non-empty payload")
|
||||
}
|
||||
|
||||
_, err = strconv.ParseInt(string(token.Payload), 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateCredentials(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
v := NewValidator([]byte(secret))
|
||||
if err := v.Validate(token.Marshal()); err != nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidSignature(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
|
||||
v := NewValidator([]byte(secret))
|
||||
if err := v.Validate(token.Marshal()); err == nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := -1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
v := NewValidator([]byte(secret))
|
||||
if err := v.Validate(token.Marshal()); err == nil {
|
||||
t.Fatalf("expected valid token: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidPayload(t *testing.T) {
|
||||
secret := "supersecret"
|
||||
timeToLive := 1 * time.Hour
|
||||
g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create generator: %v", err)
|
||||
}
|
||||
|
||||
token, err := g.GenerateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
|
||||
v := NewValidator([]byte(secret))
|
||||
if err := v.Validate(token.Marshal()); err == nil {
|
||||
t.Fatalf("expected invalid token due to invalid payload")
|
||||
}
|
||||
}
|
||||
39
shared/relay/auth/hmac/v2/token.go
Normal file
39
shared/relay/auth/hmac/v2/token.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package v2
|
||||
|
||||
import "errors"
|
||||
|
||||
type Token struct {
|
||||
AuthAlgo AuthAlgo
|
||||
Signature []byte
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func (t *Token) Marshal() []byte {
|
||||
size := 1 + len(t.Signature) + len(t.Payload)
|
||||
|
||||
buf := make([]byte, size)
|
||||
|
||||
buf[0] = byte(t.AuthAlgo)
|
||||
copy(buf[1:], t.Signature)
|
||||
copy(buf[1+len(t.Signature):], t.Payload)
|
||||
|
||||
return buf
|
||||
}
|
||||
|
||||
func UnmarshalToken(data []byte) (*Token, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, errors.New("invalid token data")
|
||||
}
|
||||
|
||||
algo := AuthAlgo(data[0])
|
||||
sigSize := algo.Size()
|
||||
if len(data) < 1+sigSize {
|
||||
return nil, errors.New("invalid token data: insufficient length")
|
||||
}
|
||||
|
||||
return &Token{
|
||||
AuthAlgo: algo,
|
||||
Signature: data[1 : 1+sigSize],
|
||||
Payload: data[1+sigSize:],
|
||||
}, nil
|
||||
}
|
||||
59
shared/relay/auth/hmac/v2/validator.go
Normal file
59
shared/relay/auth/hmac/v2/validator.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
const minLengthUnixTimestamp = 10
|
||||
|
||||
type Validator struct {
|
||||
secret []byte
|
||||
}
|
||||
|
||||
func NewValidator(secret []byte) *Validator {
|
||||
return &Validator{secret: secret}
|
||||
}
|
||||
|
||||
func (v *Validator) Validate(data any) error {
|
||||
d, ok := data.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid data type")
|
||||
}
|
||||
|
||||
token, err := UnmarshalToken(d)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshal token: %w", err)
|
||||
}
|
||||
|
||||
if len(token.Payload) < minLengthUnixTimestamp {
|
||||
return errors.New("invalid payload: insufficient length")
|
||||
}
|
||||
|
||||
hashFunc := token.AuthAlgo.New()
|
||||
if hashFunc == nil {
|
||||
return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
|
||||
}
|
||||
|
||||
h := hmac.New(hashFunc, v.secret)
|
||||
h.Write(token.Payload)
|
||||
expectedMAC := h.Sum(nil)
|
||||
|
||||
if !hmac.Equal(token.Signature, expectedMAC) {
|
||||
return errors.New("invalid signature")
|
||||
}
|
||||
|
||||
timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid payload: %w", err)
|
||||
}
|
||||
|
||||
if time.Now().Unix() > timestamp {
|
||||
return fmt.Errorf("expired token")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
33
shared/relay/auth/hmac/validator.go
Normal file
33
shared/relay/auth/hmac/validator.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package hmac
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type TimedHMACValidator struct {
|
||||
*TimedHMAC
|
||||
}
|
||||
|
||||
func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
|
||||
ta := NewTimedHMAC(secret, duration)
|
||||
return &TimedHMACValidator{
|
||||
ta,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *TimedHMACValidator) Validate(credentials any) error {
|
||||
b, ok := credentials.([]byte)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid credentials type")
|
||||
}
|
||||
c, err := unmarshalToken(b)
|
||||
if err != nil {
|
||||
log.Debugf("failed to unmarshal token: %s", err)
|
||||
return err
|
||||
}
|
||||
return a.TimedHMAC.Validate(sha256.New, c)
|
||||
}
|
||||
28
shared/relay/auth/validator.go
Normal file
28
shared/relay/auth/validator.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
authv2 "github.com/netbirdio/netbird/shared/relay/auth/hmac/v2"
|
||||
)
|
||||
|
||||
type TimedHMACValidator struct {
|
||||
authenticatorV2 *authv2.Validator
|
||||
authenticator *auth.TimedHMACValidator
|
||||
}
|
||||
|
||||
func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
|
||||
return &TimedHMACValidator{
|
||||
authenticatorV2: authv2.NewValidator(secret),
|
||||
authenticator: auth.NewTimedHMACValidator(string(secret), duration),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *TimedHMACValidator) Validate(credentials any) error {
|
||||
return a.authenticatorV2.Validate(credentials)
|
||||
}
|
||||
|
||||
func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
|
||||
return a.authenticator.Validate(credentials)
|
||||
}
|
||||
13
shared/relay/client/addr.go
Normal file
13
shared/relay/client/addr.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package client
|
||||
|
||||
type RelayAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a RelayAddr) Network() string {
|
||||
return "relay"
|
||||
}
|
||||
|
||||
func (a RelayAddr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
665
shared/relay/client/client.go
Normal file
665
shared/relay/client/client.go
Normal file
@@ -0,0 +1,665 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/quic"
|
||||
"github.com/netbirdio/netbird/shared/relay/client/dialer/ws"
|
||||
"github.com/netbirdio/netbird/relay/healthcheck"
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 8820
|
||||
serverResponseTimeout = 8 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConnAlreadyExists = fmt.Errorf("connection already exists")
|
||||
)
|
||||
|
||||
type internalStopFlag struct {
|
||||
sync.Mutex
|
||||
stop bool
|
||||
}
|
||||
|
||||
func newInternalStopFlag() *internalStopFlag {
|
||||
return &internalStopFlag{}
|
||||
}
|
||||
|
||||
func (isf *internalStopFlag) set() {
|
||||
isf.Lock()
|
||||
defer isf.Unlock()
|
||||
isf.stop = true
|
||||
}
|
||||
|
||||
func (isf *internalStopFlag) isSet() bool {
|
||||
isf.Lock()
|
||||
defer isf.Unlock()
|
||||
return isf.stop
|
||||
}
|
||||
|
||||
// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer.
|
||||
type Msg struct {
|
||||
Payload []byte
|
||||
|
||||
bufPool *sync.Pool
|
||||
bufPtr *[]byte
|
||||
}
|
||||
|
||||
func (m *Msg) Free() {
|
||||
m.bufPool.Put(m.bufPtr)
|
||||
}
|
||||
|
||||
// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
|
||||
// server and forwarding them to the upper layer content reader.
|
||||
type connContainer struct {
|
||||
log *log.Entry
|
||||
conn *Conn
|
||||
messages chan Msg
|
||||
msgChanLock sync.Mutex
|
||||
closed bool // flag to check if channel is closed
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &connContainer{
|
||||
log: log,
|
||||
conn: conn,
|
||||
messages: messages,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *connContainer) writeMsg(msg Msg) {
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
|
||||
if cc.closed {
|
||||
msg.Free()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case cc.messages <- msg:
|
||||
case <-cc.ctx.Done():
|
||||
msg.Free()
|
||||
default:
|
||||
msg.Free()
|
||||
}
|
||||
}
|
||||
|
||||
func (cc *connContainer) close() {
|
||||
cc.cancel()
|
||||
|
||||
cc.msgChanLock.Lock()
|
||||
defer cc.msgChanLock.Unlock()
|
||||
|
||||
if cc.closed {
|
||||
return
|
||||
}
|
||||
|
||||
cc.closed = true
|
||||
close(cc.messages)
|
||||
|
||||
for msg := range cc.messages {
|
||||
msg.Free()
|
||||
}
|
||||
}
|
||||
|
||||
// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
|
||||
// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection,
|
||||
// the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
|
||||
// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server.
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
connectionURL string
|
||||
authTokenStore *auth.TokenStore
|
||||
hashedID messages.PeerID
|
||||
|
||||
bufPool *sync.Pool
|
||||
|
||||
relayConn net.Conn
|
||||
conns map[messages.PeerID]*connContainer
|
||||
serviceIsRunning bool
|
||||
mu sync.Mutex // protect serviceIsRunning and conns
|
||||
readLoopMutex sync.Mutex
|
||||
wgReadLoop sync.WaitGroup
|
||||
instanceURL *RelayAddr
|
||||
muInstanceURL sync.Mutex
|
||||
|
||||
onDisconnectListener func(string)
|
||||
listenerMutex sync.Mutex
|
||||
|
||||
stateSubscription *PeersStateSubscription
|
||||
}
|
||||
|
||||
// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect
|
||||
func NewClient(serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
|
||||
hashedID := messages.HashID(peerID)
|
||||
relayLog := log.WithFields(log.Fields{"relay": serverURL})
|
||||
|
||||
c := &Client{
|
||||
log: relayLog,
|
||||
connectionURL: serverURL,
|
||||
authTokenStore: authTokenStore,
|
||||
hashedID: hashedID,
|
||||
bufPool: &sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, bufferSize)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
conns: make(map[messages.PeerID]*connContainer),
|
||||
}
|
||||
|
||||
c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedID)
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
c.log.Infof("connecting to relay server")
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.serviceIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
instanceURL, err := c.connect(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.muInstanceURL.Lock()
|
||||
c.instanceURL = instanceURL
|
||||
c.muInstanceURL.Unlock()
|
||||
|
||||
c.stateSubscription = NewPeersStateSubscription(c.log, c.relayConn, c.closeConnsByPeerID)
|
||||
|
||||
c.log = c.log.WithField("relay", instanceURL.String())
|
||||
c.log.Infof("relay connection established")
|
||||
|
||||
c.serviceIsRunning = true
|
||||
|
||||
internallyStoppedFlag := newInternalStopFlag()
|
||||
hc := healthcheck.NewReceiver(c.log)
|
||||
go c.listenForStopEvents(ctx, hc, c.relayConn, internallyStoppedFlag)
|
||||
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop(hc, c.relayConn, internallyStoppedFlag)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress
|
||||
// to the relay server, the function will block until the connection is established or timed out. Otherwise,
|
||||
// it will return immediately.
|
||||
// It block until the server confirm the peer is online.
|
||||
// todo: what should happen if call with the same peerID with multiple times?
|
||||
func (c *Client) OpenConn(ctx context.Context, dstPeerID string) (net.Conn, error) {
|
||||
peerID := messages.HashID(dstPeerID)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
_, ok := c.conns[peerID]
|
||||
if ok {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if err := c.stateSubscription.WaitToBeOnlineAndSubscribe(ctx, peerID); err != nil {
|
||||
c.log.Errorf("peer not available: %s, %s", peerID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.log.Infof("remote peer is available, prepare the relayed connection: %s", peerID)
|
||||
msgChannel := make(chan Msg, 100)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
instanceURL := c.instanceURL
|
||||
c.muInstanceURL.Unlock()
|
||||
conn := NewConn(c, peerID, msgChannel, instanceURL)
|
||||
|
||||
_, ok = c.conns[peerID]
|
||||
if ok {
|
||||
c.mu.Unlock()
|
||||
_ = conn.Close()
|
||||
return nil, ErrConnAlreadyExists
|
||||
}
|
||||
c.conns[peerID] = newConnContainer(c.log, conn, msgChannel)
|
||||
c.mu.Unlock()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
|
||||
func (c *Client) ServerInstanceURL() (string, error) {
|
||||
c.muInstanceURL.Lock()
|
||||
defer c.muInstanceURL.Unlock()
|
||||
if c.instanceURL == nil {
|
||||
return "", fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
return c.instanceURL.String(), nil
|
||||
}
|
||||
|
||||
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
|
||||
func (c *Client) SetOnDisconnectListener(fn func(string)) {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
c.onDisconnectListener = fn
|
||||
}
|
||||
|
||||
// HasConns returns true if there are connections.
|
||||
func (c *Client) HasConns() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return len(c.conns) > 0
|
||||
}
|
||||
|
||||
func (c *Client) Ready() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.serviceIsRunning
|
||||
}
|
||||
|
||||
// Close closes the connection to the relay server and all connections to other peers.
|
||||
func (c *Client) Close() error {
|
||||
return c.close(true)
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) (*RelayAddr, error) {
|
||||
rd := dialer.NewRaceDial(c.log, dialer.DefaultConnectionTimeout, c.connectionURL, quic.Dialer{}, ws.Dialer{})
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.relayConn = conn
|
||||
|
||||
instanceURL, err := c.handShake(ctx)
|
||||
if err != nil {
|
||||
cErr := conn.Close()
|
||||
if cErr != nil {
|
||||
c.log.Errorf("failed to close connection: %s", cErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return instanceURL, nil
|
||||
}
|
||||
|
||||
func (c *Client) handShake(ctx context.Context) (*RelayAddr, error) {
|
||||
msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to marshal auth message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to send auth message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
buf := make([]byte, messages.MaxHandshakeRespSize)
|
||||
n, err := c.readWithTimeout(ctx, buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to read auth response: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = messages.ValidateVersion(buf[:n])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validate version: %w", err)
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf[:n])
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msgType != messages.MsgTypeAuthResponse {
|
||||
c.log.Errorf("unexpected message type: %s", msgType)
|
||||
return nil, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
addr, err := messages.UnmarshalAuthResponse(buf[:n])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RelayAddr{addr: addr}, nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop(hc *healthcheck.Receiver, relayConn net.Conn, internallyStoppedFlag *internalStopFlag) {
|
||||
var (
|
||||
errExit error
|
||||
n int
|
||||
)
|
||||
for {
|
||||
bufPtr := c.bufPool.Get().(*[]byte)
|
||||
buf := *bufPtr
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.log.Infof("start to Relay read loop exit")
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Errorf("failed to read message from relay server: %s", errExit)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.bufPool.Put(bufPtr)
|
||||
break
|
||||
}
|
||||
|
||||
buf = buf[:n]
|
||||
|
||||
_, err := messages.ValidateVersion(buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to validate protocol version: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
msgType, err := messages.DetermineServerMessageType(buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to determine message type: %s", err)
|
||||
c.bufPool.Put(bufPtr)
|
||||
continue
|
||||
}
|
||||
|
||||
if !c.handleMsg(msgType, buf, bufPtr, hc, internallyStoppedFlag) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
hc.Stop()
|
||||
|
||||
c.stateSubscription.Cleanup()
|
||||
c.wgReadLoop.Done()
|
||||
_ = c.close(false)
|
||||
c.notifyDisconnected()
|
||||
}
|
||||
|
||||
func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
|
||||
switch msgType {
|
||||
case messages.MsgTypeHealthCheck:
|
||||
c.handleHealthCheck(hc, internallyStoppedFlag)
|
||||
c.bufPool.Put(bufPtr)
|
||||
case messages.MsgTypeTransport:
|
||||
return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
|
||||
case messages.MsgTypePeersOnline:
|
||||
c.handlePeersOnlineMsg(buf)
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
case messages.MsgTypePeersWentOffline:
|
||||
c.handlePeersWentOfflineMsg(buf)
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
case messages.MsgTypeClose:
|
||||
c.log.Debugf("relay connection close by server")
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
|
||||
msg := messages.MarshalHealthcheck()
|
||||
_, wErr := c.relayConn.Write(msg)
|
||||
if wErr != nil {
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Errorf("failed to send heartbeat: %s", wErr)
|
||||
}
|
||||
}
|
||||
hc.Heartbeat()
|
||||
}
|
||||
|
||||
func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
|
||||
peerID, payload, err := messages.UnmarshalTransportMsg(buf)
|
||||
if err != nil {
|
||||
if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
|
||||
c.log.Errorf("failed to parse transport message: %v", err)
|
||||
}
|
||||
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
c.bufPool.Put(bufPtr)
|
||||
return false
|
||||
}
|
||||
container, ok := c.conns[*peerID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.log.Errorf("peer not found: %s", peerID.String())
|
||||
c.bufPool.Put(bufPtr)
|
||||
return true
|
||||
}
|
||||
msg := Msg{
|
||||
bufPool: c.bufPool,
|
||||
bufPtr: bufPtr,
|
||||
Payload: payload,
|
||||
}
|
||||
container.writeMsg(msg)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) writeTo(connReference *Conn, dstID messages.PeerID, payload []byte) (int, error) {
|
||||
c.mu.Lock()
|
||||
conn, ok := c.conns[dstID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
if conn.conn != connReference {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// todo: use buffer pool instead of create new transport msg.
|
||||
msg, err := messages.MarshalTransportMsg(dstID, payload)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to marshal transport message: %s", err)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// the write always return with 0 length because the underling does not support the size feedback.
|
||||
_, err = c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to write transport message: %s", err)
|
||||
}
|
||||
return len(payload), err
|
||||
}
|
||||
|
||||
func (c *Client) listenForStopEvents(ctx context.Context, hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-hc.OnTimeout:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.log.Errorf("health check timeout")
|
||||
internalStopFlag.set()
|
||||
if err := conn.Close(); err != nil {
|
||||
// ignore the err handling because the readLoop will handle it
|
||||
c.log.Warnf("failed to close connection: %s", err)
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
err := c.close(true)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to teardown connection: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeAllConns() {
|
||||
for _, container := range c.conns {
|
||||
container.close()
|
||||
}
|
||||
c.conns = make(map[messages.PeerID]*connContainer)
|
||||
}
|
||||
|
||||
func (c *Client) closeConnsByPeerID(peerIDs []messages.PeerID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
container, ok := c.conns[peerID]
|
||||
if !ok {
|
||||
c.log.Warnf("can not close connection, peer not found: %s", peerID)
|
||||
continue
|
||||
}
|
||||
|
||||
container.log.Infof("remote peer has been disconnected, free up connection: %s", peerID)
|
||||
container.close()
|
||||
delete(c.conns, peerID)
|
||||
}
|
||||
|
||||
if err := c.stateSubscription.UnsubscribeStateChange(peerIDs); err != nil {
|
||||
c.log.Errorf("failed to unsubscribe from peer state change: %s, %s", peerIDs, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeConn(connReference *Conn, id messages.PeerID) error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
container, ok := c.conns[id]
|
||||
if !ok {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
if container.conn != connReference {
|
||||
return fmt.Errorf("conn reference mismatch")
|
||||
}
|
||||
|
||||
if err := c.stateSubscription.UnsubscribeStateChange([]messages.PeerID{id}); err != nil {
|
||||
container.log.Errorf("failed to unsubscribe from peer state change: %s", err)
|
||||
}
|
||||
|
||||
c.log.Infof("free up connection to peer: %s", id)
|
||||
delete(c.conns, id)
|
||||
container.close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) close(gracefullyExit bool) error {
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
var err error
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
c.log.Warn("relay connection was already marked as not running")
|
||||
return nil
|
||||
}
|
||||
c.serviceIsRunning = false
|
||||
|
||||
c.muInstanceURL.Lock()
|
||||
c.instanceURL = nil
|
||||
c.muInstanceURL.Unlock()
|
||||
|
||||
c.log.Infof("closing all peer connections")
|
||||
c.closeAllConns()
|
||||
if gracefullyExit {
|
||||
c.writeCloseMsg()
|
||||
}
|
||||
err = c.relayConn.Close()
|
||||
c.mu.Unlock()
|
||||
|
||||
c.log.Infof("waiting for read loop to close")
|
||||
c.wgReadLoop.Wait()
|
||||
c.log.Infof("relay connection closed")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) notifyDisconnected() {
|
||||
c.listenerMutex.Lock()
|
||||
defer c.listenerMutex.Unlock()
|
||||
|
||||
if c.onDisconnectListener == nil {
|
||||
return
|
||||
}
|
||||
go c.onDisconnectListener(c.connectionURL)
|
||||
}
|
||||
|
||||
func (c *Client) writeCloseMsg() {
|
||||
msg := messages.MarshalCloseMsg()
|
||||
_, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to send close message: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) readWithTimeout(ctx context.Context, buf []byte) (int, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, serverResponseTimeout)
|
||||
defer cancel()
|
||||
|
||||
readDone := make(chan struct{})
|
||||
var (
|
||||
n int
|
||||
err error
|
||||
)
|
||||
|
||||
go func() {
|
||||
n, err = c.relayConn.Read(buf)
|
||||
close(readDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, fmt.Errorf("read operation timed out")
|
||||
case <-readDone:
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) handlePeersOnlineMsg(buf []byte) {
|
||||
peersID, err := messages.UnmarshalPeersOnlineMsg(buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to unmarshal peers online msg: %s", err)
|
||||
return
|
||||
}
|
||||
c.stateSubscription.OnPeersOnline(peersID)
|
||||
}
|
||||
|
||||
func (c *Client) handlePeersWentOfflineMsg(buf []byte) {
|
||||
peersID, err := messages.UnMarshalPeersWentOffline(buf)
|
||||
if err != nil {
|
||||
c.log.Errorf("failed to unmarshal peers went offline msg: %s", err)
|
||||
return
|
||||
}
|
||||
c.stateSubscription.OnPeersWentOffline(peersID)
|
||||
}
|
||||
743
shared/relay/client/client_test.go
Normal file
743
shared/relay/client/client_test.go
Normal file
@@ -0,0 +1,743 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/auth/allow"
|
||||
"github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
var (
|
||||
hmacTokenStore = &hmac.TokenStore{}
|
||||
serverListenAddr = "127.0.0.1:1234"
|
||||
serverURL = "rel://127.0.0.1:1234"
|
||||
serverCfg = server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: serverURL,
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
}
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
_ = util.InitLog("debug", util.LogConsole)
|
||||
code := m.Run()
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
listenCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
err := srv.Listen(listenCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
t.Log("alice connecting to server")
|
||||
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientAlice.Close()
|
||||
|
||||
t.Log("placeholder connecting to server")
|
||||
clientPlaceHolder := NewClient(serverURL, hmacTokenStore, "clientPlaceHolder")
|
||||
err = clientPlaceHolder.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientPlaceHolder.Close()
|
||||
|
||||
t.Log("Bob connecting to server")
|
||||
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||
err = clientBob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer clientBob.Close()
|
||||
|
||||
t.Log("Alice open connection to Bob")
|
||||
connAliceToBob, err := clientAlice.OpenConn(ctx, "bob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("Bob open connection to Alice")
|
||||
connBobToAlice, err := clientBob.OpenConn(ctx, "alice")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
log.Debugf("alice sent message to bob")
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
log.Debugf("on new message from alice to bob")
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistration(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for server to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
_ = srv.Shutdown(ctx)
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
err = srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistrationTimeout(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind UDP server: %s", err)
|
||||
}
|
||||
defer func(fakeUDPListener *net.UDPConn) {
|
||||
_ = fakeUDPListener.Close()
|
||||
}(fakeUDPListener)
|
||||
|
||||
fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
|
||||
Port: 1234,
|
||||
IP: net.ParseIP("0.0.0.0"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind TCP server: %s", err)
|
||||
}
|
||||
defer func(fakeTCPListener *net.TCPListener) {
|
||||
_ = fakeTCPListener.Close()
|
||||
}(fakeTCPListener)
|
||||
|
||||
clientAlice := NewClient("127.0.0.1:1234", hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err == nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
log.Debugf("%s", err)
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close conn: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEcho(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAlice := "alice"
|
||||
idBob := "bob"
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Alice client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
|
||||
err = clientBob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientBob.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Bob client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = connBobToAlice.Write(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
n, err = connAliceToBob.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindToUnavailabePeer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||
if err == nil {
|
||||
t.Errorf("expected error when binding to unavailable peer, got nil")
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBindReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
clientBob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||
err = clientBob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
chBob, err := clientBob.OpenConn(ctx, "alice")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing client Alice")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
clientAlice = NewClient(serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
chAlice, err := clientAlice.OpenConn(ctx, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
testString := "hello alice, I am bob"
|
||||
_, err = chBob.Write([]byte(testString))
|
||||
if err == nil {
|
||||
t.Errorf("expected error when writing to channel, got nil")
|
||||
}
|
||||
|
||||
chBob, err = clientBob.OpenConn(ctx, "alice")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = chBob.Write([]byte(testString))
|
||||
if err != nil {
|
||||
t.Errorf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := chAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if testString != string(buf[:n]) {
|
||||
t.Errorf("expected %s, got %s", testString, string(buf[:n]))
|
||||
}
|
||||
|
||||
log.Infof("closing client")
|
||||
err = clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
log.Infof("closing server")
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
bob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||
err = bob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn(ctx, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
log.Infof("closing connection")
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
_, err = conn.Write([]byte("hello"))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected writing from closed connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseRelayConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
bob := NewClient(serverURL, hmacTokenStore, "bob")
|
||||
err = bob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverURL, hmacTokenStore, "alice")
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
conn, err := clientAlice.OpenConn(ctx, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
_ = clientAlice.relayConn.Close()
|
||||
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
_, err = clientAlice.OpenConn(ctx, "bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByServer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv1, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||
if err = relayClient.Connect(ctx); err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := relayClient.Close(); err != nil {
|
||||
log.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
disconnected := make(chan struct{})
|
||||
relayClient.SetOnDisconnectListener(func(_ string) {
|
||||
log.Infof("client disconnected")
|
||||
close(disconnected)
|
||||
})
|
||||
|
||||
err = srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
case <-time.After(3 * time.Second):
|
||||
log.Errorf("timeout waiting for client to disconnect")
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn(ctx, "bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseByClient(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
log.Debugf("connect by alice")
|
||||
relayClient := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||
err = relayClient.Connect(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
|
||||
err = relayClient.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close client: %s", err)
|
||||
}
|
||||
|
||||
_, err = relayClient.OpenConn(ctx, "bob")
|
||||
if err == nil {
|
||||
t.Errorf("unexpected opening connection to closed server")
|
||||
}
|
||||
|
||||
err = srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close server: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseNotDrainedChannel(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
idAlice := "alice"
|
||||
idBob := "bob"
|
||||
srvCfg := server.ListenerConfig{Address: serverListenAddr}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv.Listen(srvCfg)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for servers to start
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewClient(serverURL, hmacTokenStore, idAlice)
|
||||
err = clientAlice.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Alice client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
clientBob := NewClient(serverURL, hmacTokenStore, idBob)
|
||||
err = clientBob.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to server: %s", err)
|
||||
}
|
||||
defer func() {
|
||||
err := clientBob.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close Bob client: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
connAliceToBob, err := clientAlice.OpenConn(ctx, idBob)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
connBobToAlice, err := clientBob.OpenConn(ctx, idAlice)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
// the internal channel buffer size is 2. So we should overflow it
|
||||
for i := 0; i < 5; i++ {
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// wait for delivery
|
||||
time.Sleep(1 * time.Second)
|
||||
err = connBobToAlice.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close channel: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForServerToStart(errChan chan error) error {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case <-time.After(300 * time.Millisecond):
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
74
shared/relay/client/conn.go
Normal file
74
shared/relay/client/conn.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
// Conn represent a connection to a relayed remote peer.
|
||||
type Conn struct {
|
||||
client *Client
|
||||
dstID messages.PeerID
|
||||
messageChan chan Msg
|
||||
instanceURL *RelayAddr
|
||||
}
|
||||
|
||||
// NewConn creates a new connection to a relayed remote peer.
|
||||
// client: the client instance, it used to send messages to the destination peer
|
||||
// dstID: the destination peer ID
|
||||
// messageChan: the channel where the messages will be received
|
||||
// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
|
||||
func NewConn(client *Client, dstID messages.PeerID, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
|
||||
c := &Conn{
|
||||
client: client,
|
||||
dstID: dstID,
|
||||
messageChan: messageChan,
|
||||
instanceURL: instanceURL,
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
return c.client.writeTo(c, c.dstID, p)
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
msg, ok := <-c.messageChan
|
||||
if !ok {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
n = copy(b, msg.Payload)
|
||||
msg.Free()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.client.closeConn(c, c.dstID)
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return c.client.relayConn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.instanceURL
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
//TODO implement me
|
||||
panic("SetReadDeadline is not implemented")
|
||||
}
|
||||
7
shared/relay/client/dialer/net/err.go
Normal file
7
shared/relay/client/dialer/net/err.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package net
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrClosedByServer = errors.New("closed by server")
|
||||
)
|
||||
97
shared/relay/client/dialer/quic/conn.go
Normal file
97
shared/relay/client/dialer/quic/conn.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netErr "github.com/netbirdio/netbird/shared/relay/client/dialer/net"
|
||||
)
|
||||
|
||||
const (
|
||||
Network = "quic"
|
||||
)
|
||||
|
||||
type Addr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a Addr) Network() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (a Addr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
session quic.Connection
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewConn(session quic.Connection) net.Conn {
|
||||
return &Conn{
|
||||
session: session,
|
||||
ctx: context.Background(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
dgram, err := c.session.ReceiveDatagram(c.ctx)
|
||||
if err != nil {
|
||||
return 0, c.remoteCloseErrHandling(err)
|
||||
}
|
||||
|
||||
n = copy(b, dgram)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
err := c.session.SendDatagram(b)
|
||||
if err != nil {
|
||||
err = c.remoteCloseErrHandling(err)
|
||||
log.Errorf("failed to write to QUIC stream: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.session.RemoteAddr()
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
if c.session != nil {
|
||||
return c.session.LocalAddr()
|
||||
}
|
||||
return Addr{addr: "unknown"}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.session.CloseWithError(0, "normal closure")
|
||||
}
|
||||
|
||||
func (c *Conn) remoteCloseErrHandling(err error) error {
|
||||
var appErr *quic.ApplicationError
|
||||
if errors.As(err, &appErr) && appErr.ErrorCode == 0x0 {
|
||||
return netErr.ErrClosedByServer
|
||||
}
|
||||
return err
|
||||
}
|
||||
99
shared/relay/client/dialer/quic/quic.go
Normal file
99
shared/relay/client/dialer/quic/quic.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/quic-go/quic-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
quictls "github.com/netbirdio/netbird/relay/tls"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
}
|
||||
|
||||
func (d Dialer) Protocol() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
quicURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the base TLS config
|
||||
tlsClientConfig := quictls.ClientQUICTLSConfig()
|
||||
|
||||
// Set ServerName to hostname if not an IP address
|
||||
host, _, splitErr := net.SplitHostPort(quicURL)
|
||||
if splitErr == nil && net.ParseIP(host) == nil {
|
||||
// It's a hostname, not an IP - modify directly
|
||||
tlsClientConfig.ServerName = host
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
KeepAlivePeriod: 30 * time.Second,
|
||||
MaxIdleTimeout: 4 * time.Minute,
|
||||
EnableDatagrams: true,
|
||||
InitialPacketSize: 1452,
|
||||
}
|
||||
|
||||
udpConn, err := nbnet.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen on UDP: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", quicURL)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve UDP address: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := quic.Dial(ctx, udpConn, udpAddr, tlsClientConfig, quicConfig)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to dial to Relay server via QUIC '%s': %s", quicURL, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := NewConn(session)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
var host string
|
||||
var defaultPort string
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(address, "rels://"):
|
||||
host = address[7:]
|
||||
defaultPort = "443"
|
||||
case strings.HasPrefix(address, "rel://"):
|
||||
host = address[6:]
|
||||
defaultPort = "80"
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
|
||||
finalHost, finalPort, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "missing port") {
|
||||
return host + ":" + defaultPort, nil
|
||||
}
|
||||
|
||||
// return any other split error as is
|
||||
return "", err
|
||||
}
|
||||
|
||||
return finalHost + ":" + finalPort, nil
|
||||
}
|
||||
98
shared/relay/client/dialer/race_dialer.go
Normal file
98
shared/relay/client/dialer/race_dialer.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultConnectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type DialeFn interface {
|
||||
Dial(ctx context.Context, address string) (net.Conn, error)
|
||||
Protocol() string
|
||||
}
|
||||
|
||||
type dialResult struct {
|
||||
Conn net.Conn
|
||||
Protocol string
|
||||
Err error
|
||||
}
|
||||
|
||||
type RaceDial struct {
|
||||
log *log.Entry
|
||||
serverURL string
|
||||
dialerFns []DialeFn
|
||||
connectionTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewRaceDial(log *log.Entry, connectionTimeout time.Duration, serverURL string, dialerFns ...DialeFn) *RaceDial {
|
||||
return &RaceDial{
|
||||
log: log,
|
||||
serverURL: serverURL,
|
||||
dialerFns: dialerFns,
|
||||
connectionTimeout: connectionTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RaceDial) Dial() (net.Conn, error) {
|
||||
connChan := make(chan dialResult, len(r.dialerFns))
|
||||
winnerConn := make(chan net.Conn, 1)
|
||||
abortCtx, abort := context.WithCancel(context.Background())
|
||||
defer abort()
|
||||
|
||||
for _, dfn := range r.dialerFns {
|
||||
go r.dial(dfn, abortCtx, connChan)
|
||||
}
|
||||
|
||||
go r.processResults(connChan, winnerConn, abort)
|
||||
|
||||
conn, ok := <-winnerConn
|
||||
if !ok {
|
||||
return nil, errors.New("failed to dial to Relay server on any protocol")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (r *RaceDial) dial(dfn DialeFn, abortCtx context.Context, connChan chan dialResult) {
|
||||
ctx, cancel := context.WithTimeout(abortCtx, r.connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
r.log.Infof("dialing Relay server via %s", dfn.Protocol())
|
||||
conn, err := dfn.Dial(ctx, r.serverURL)
|
||||
connChan <- dialResult{Conn: conn, Protocol: dfn.Protocol(), Err: err}
|
||||
}
|
||||
|
||||
func (r *RaceDial) processResults(connChan chan dialResult, winnerConn chan net.Conn, abort context.CancelFunc) {
|
||||
var hasWinner bool
|
||||
for i := 0; i < len(r.dialerFns); i++ {
|
||||
dr := <-connChan
|
||||
if dr.Err != nil {
|
||||
if errors.Is(dr.Err, context.Canceled) {
|
||||
r.log.Infof("connection attempt aborted via: %s", dr.Protocol)
|
||||
} else {
|
||||
r.log.Errorf("failed to dial via %s: %s", dr.Protocol, dr.Err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if hasWinner {
|
||||
if cerr := dr.Conn.Close(); cerr != nil {
|
||||
r.log.Warnf("failed to close connection via %s: %s", dr.Protocol, cerr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
r.log.Infof("successfully dialed via: %s", dr.Protocol)
|
||||
|
||||
abort()
|
||||
hasWinner = true
|
||||
winnerConn <- dr.Conn
|
||||
}
|
||||
close(winnerConn)
|
||||
}
|
||||
252
shared/relay/client/dialer/race_dialer_test.go
Normal file
252
shared/relay/client/dialer/race_dialer_test.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type MockAddr struct {
|
||||
network string
|
||||
}
|
||||
|
||||
func (m *MockAddr) Network() string {
|
||||
return m.network
|
||||
}
|
||||
|
||||
func (m *MockAddr) String() string {
|
||||
return "1.2.3.4"
|
||||
}
|
||||
|
||||
// MockDialer is a mock implementation of DialeFn
|
||||
type MockDialer struct {
|
||||
dialFunc func(ctx context.Context, address string) (net.Conn, error)
|
||||
protocolStr string
|
||||
}
|
||||
|
||||
func (m *MockDialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
return m.dialFunc(ctx, address)
|
||||
}
|
||||
|
||||
func (m *MockDialer) Protocol() string {
|
||||
return m.protocolStr
|
||||
}
|
||||
|
||||
// MockConn implements net.Conn for testing
|
||||
type MockConn struct {
|
||||
remoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (m *MockConn) Read(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
return m.remoteAddr
|
||||
}
|
||||
|
||||
func (m *MockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRaceDialEmptyDialers(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error with empty dialers, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection with empty dialers, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialSingleSuccessfulDialer(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto := "test-protocol"
|
||||
|
||||
mockConn := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto},
|
||||
}
|
||||
|
||||
mockDialer := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return mockConn, nil
|
||||
},
|
||||
protocolStr: proto,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Errorf("Expected non-nil connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialMultipleDialersWithOneSuccess(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto2 := "protocol2"
|
||||
|
||||
mockConn2 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto2},
|
||||
}
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("first dialer failed")
|
||||
},
|
||||
protocolStr: "proto1",
|
||||
}
|
||||
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return mockConn2, nil
|
||||
},
|
||||
protocolStr: "proto2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn.RemoteAddr().Network() != proto2 {
|
||||
t.Errorf("Expected connection with protocol %s, got %s", proto2, conn.RemoteAddr().Network())
|
||||
}
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
func TestRaceDialTimeout(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
mockDialer := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
},
|
||||
protocolStr: "proto1",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, 3*time.Second, serverURL, mockDialer)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialAllDialersFail(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("first dialer failed")
|
||||
},
|
||||
protocolStr: "protocol1",
|
||||
}
|
||||
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
return nil, errors.New("second dialer failed")
|
||||
},
|
||||
protocolStr: "protocol2",
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err == nil {
|
||||
t.Errorf("Expected an error, got nil")
|
||||
}
|
||||
if conn != nil {
|
||||
t.Errorf("Expected nil connection, got %v", conn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRaceDialFirstSuccessfulDialerWins(t *testing.T) {
|
||||
logger := logrus.NewEntry(logrus.New())
|
||||
serverURL := "test.server.com"
|
||||
proto1 := "protocol1"
|
||||
proto2 := "protocol2"
|
||||
|
||||
mockConn1 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto1},
|
||||
}
|
||||
|
||||
mockConn2 := &MockConn{
|
||||
remoteAddr: &MockAddr{network: proto2},
|
||||
}
|
||||
|
||||
mockDialer1 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
time.Sleep(1 * time.Second)
|
||||
return mockConn1, nil
|
||||
},
|
||||
protocolStr: proto1,
|
||||
}
|
||||
|
||||
mock2err := make(chan error)
|
||||
mockDialer2 := &MockDialer{
|
||||
dialFunc: func(ctx context.Context, address string) (net.Conn, error) {
|
||||
<-ctx.Done()
|
||||
mock2err <- ctx.Err()
|
||||
return mockConn2, ctx.Err()
|
||||
},
|
||||
protocolStr: proto2,
|
||||
}
|
||||
|
||||
rd := NewRaceDial(logger, DefaultConnectionTimeout, serverURL, mockDialer1, mockDialer2)
|
||||
conn, err := rd.Dial()
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
if conn == nil {
|
||||
t.Errorf("Expected non-nil connection")
|
||||
}
|
||||
if conn != mockConn1 {
|
||||
t.Errorf("Expected first connection, got %v", conn)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Errorf("Timed out waiting for second dialer to finish")
|
||||
case err := <-mock2err:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
17
shared/relay/client/dialer/ws/addr.go
Normal file
17
shared/relay/client/dialer/ws/addr.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package ws
|
||||
|
||||
const (
|
||||
Network = "ws"
|
||||
)
|
||||
|
||||
type WebsocketAddr struct {
|
||||
addr string
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) Network() string {
|
||||
return Network
|
||||
}
|
||||
|
||||
func (a WebsocketAddr) String() string {
|
||||
return a.addr
|
||||
}
|
||||
67
shared/relay/client/dialer/ws/conn.go
Normal file
67
shared/relay/client/dialer/ws/conn.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
ctx context.Context
|
||||
*websocket.Conn
|
||||
remoteAddr WebsocketAddr
|
||||
}
|
||||
|
||||
func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
|
||||
return &Conn{
|
||||
ctx: context.Background(),
|
||||
Conn: wsConn,
|
||||
remoteAddr: WebsocketAddr{serverAddress},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
t, ioReader, err := c.Conn.Reader(c.ctx)
|
||||
if err != nil {
|
||||
// todo use ErrClosedByServer
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if t != websocket.MessageBinary {
|
||||
return 0, fmt.Errorf("unexpected message type")
|
||||
}
|
||||
|
||||
return ioReader.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
err = c.Conn.Write(c.ctx, websocket.MessageBinary, b)
|
||||
return 0, err
|
||||
}
|
||||
|
||||
func (c *Conn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *Conn) LocalAddr() net.Addr {
|
||||
return WebsocketAddr{addr: "unknown"}
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetReadDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetWriteDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
return fmt.Errorf("SetDeadline is not implemented")
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
return c.Conn.CloseNow()
|
||||
}
|
||||
90
shared/relay/client/dialer/ws/ws.go
Normal file
90
shared/relay/client/dialer/ws/ws.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/server/listener/ws"
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
nbnet "github.com/netbirdio/netbird/util/net"
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
}
|
||||
|
||||
func (d Dialer) Protocol() string {
|
||||
return "WS"
|
||||
}
|
||||
|
||||
func (d Dialer) Dial(ctx context.Context, address string) (net.Conn, error) {
|
||||
wsURL, err := prepareURL(address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts := &websocket.DialOptions{
|
||||
HTTPClient: httpClientNbDialer(),
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(wsURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedURL.Path = ws.URLPath
|
||||
|
||||
wsConn, resp, err := websocket.Dial(ctx, parsedURL.String(), opts)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, err
|
||||
}
|
||||
log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
|
||||
return nil, err
|
||||
}
|
||||
if resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
|
||||
conn := NewConn(wsConn, address)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func prepareURL(address string) (string, error) {
|
||||
if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
|
||||
return "", fmt.Errorf("unsupported scheme: %s", address)
|
||||
}
|
||||
|
||||
return strings.Replace(address, "rel", "ws", 1), nil
|
||||
}
|
||||
|
||||
func httpClientNbDialer() *http.Client {
|
||||
customDialer := nbnet.NewDialer()
|
||||
|
||||
certPool, err := x509.SystemCertPool()
|
||||
if err != nil || certPool == nil {
|
||||
log.Debugf("System cert pool not available; falling back to embedded cert, error: %v", err)
|
||||
certPool = embeddedroots.Get()
|
||||
}
|
||||
|
||||
customTransport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return customDialer.DialContext(ctx, network, addr)
|
||||
},
|
||||
TLSClientConfig: &tls.Config{
|
||||
RootCAs: certPool,
|
||||
},
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: customTransport,
|
||||
}
|
||||
}
|
||||
12
shared/relay/client/doc.go
Normal file
12
shared/relay/client/doc.go
Normal file
@@ -0,0 +1,12 @@
|
||||
/*
|
||||
Package client contains the implementation of the Relay client.
|
||||
|
||||
The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages,
|
||||
Keep persistent connection with the Relay server and handle the connection issues.
|
||||
It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security).
|
||||
|
||||
If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to
|
||||
the relay server. The connection with these relay servers will be closed if there is no active connection. The peers
|
||||
negotiate the common relay instance via signaling service.
|
||||
*/
|
||||
package client
|
||||
10
shared/relay/client/go.sum
Normal file
10
shared/relay/client/go.sum
Normal file
@@ -0,0 +1,10 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
149
shared/relay/client/guard.go
Normal file
149
shared/relay/client/guard.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// TODO: make it configurable, the manager should validate all configurable parameters
|
||||
reconnectingTimeout = 60 * time.Second
|
||||
)
|
||||
|
||||
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
|
||||
type Guard struct {
|
||||
// OnNewRelayClient is a channel that is used to notify the relay manager about a new relay client instance.
|
||||
OnNewRelayClient chan *Client
|
||||
OnReconnected chan struct{}
|
||||
serverPicker *ServerPicker
|
||||
}
|
||||
|
||||
// NewGuard creates a new guard for the relay client.
|
||||
func NewGuard(sp *ServerPicker) *Guard {
|
||||
g := &Guard{
|
||||
OnNewRelayClient: make(chan *Client, 1),
|
||||
OnReconnected: make(chan struct{}, 1),
|
||||
serverPicker: sp,
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
// StartReconnectTrys is called when the relay client is disconnected from the relay server.
|
||||
// It attempts to reconnect to the relay server. The function first tries a quick reconnect
|
||||
// to the same server that was used before, if the server URL is still valid. If the quick
|
||||
// reconnect fails, it starts a ticker to periodically attempt server picking until it
|
||||
// succeeds or the context is done.
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: The context to control the lifecycle of the reconnection attempts.
|
||||
// - relayClient: The relay client instance that was disconnected.
|
||||
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
|
||||
func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
|
||||
// try to reconnect to the same server
|
||||
if ok := g.tryToQuickReconnect(ctx, relayClient); ok {
|
||||
g.notifyReconnected()
|
||||
return
|
||||
}
|
||||
|
||||
// start a ticker to pick a new server
|
||||
ticker := exponentTicker(ctx)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := g.retry(ctx); err != nil {
|
||||
log.Errorf("failed to pick new Relay server: %s", err)
|
||||
continue
|
||||
}
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) tryToQuickReconnect(parentCtx context.Context, rc *Client) bool {
|
||||
if rc == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !g.isServerURLStillValid(rc) {
|
||||
return false
|
||||
}
|
||||
|
||||
if cancelled := waiteBeforeRetry(parentCtx); !cancelled {
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
|
||||
|
||||
if err := rc.Connect(parentCtx); err != nil {
|
||||
log.Errorf("failed to reconnect to relay server: %s", err)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (g *Guard) retry(ctx context.Context) error {
|
||||
log.Infof("try to pick up a new Relay server")
|
||||
relayClient, err := g.serverPicker.PickServer(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// prevent to work with a deprecated Relay client instance
|
||||
g.drainRelayClientChan()
|
||||
|
||||
g.OnNewRelayClient <- relayClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Guard) drainRelayClientChan() {
|
||||
select {
|
||||
case <-g.OnNewRelayClient:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (g *Guard) isServerURLStillValid(rc *Client) bool {
|
||||
for _, url := range g.serverPicker.ServerURLs.Load().([]string) {
|
||||
if url == rc.connectionURL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (g *Guard) notifyReconnected() {
|
||||
select {
|
||||
case g.OnReconnected <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func exponentTicker(ctx context.Context) *backoff.Ticker {
|
||||
bo := backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 2 * time.Second,
|
||||
Multiplier: 2,
|
||||
MaxInterval: reconnectingTimeout,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
|
||||
return backoff.NewTicker(bo)
|
||||
}
|
||||
|
||||
func waiteBeforeRetry(ctx context.Context) bool {
|
||||
timer := time.NewTimer(1500 * time.Millisecond)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
return true
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
404
shared/relay/client/manager.go
Normal file
404
shared/relay/client/manager.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
relayAuth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
)
|
||||
|
||||
var (
|
||||
relayCleanupInterval = 60 * time.Second
|
||||
keepUnusedServerTime = 5 * time.Second
|
||||
|
||||
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
|
||||
)
|
||||
|
||||
// RelayTrack hold the relay clients for the foreign relay servers.
|
||||
// With the mutex can ensure we can open new connection in case the relay connection has been established with
|
||||
// the relay server.
|
||||
type RelayTrack struct {
|
||||
sync.RWMutex
|
||||
relayClient *Client
|
||||
err error
|
||||
created time.Time
|
||||
}
|
||||
|
||||
func NewRelayTrack() *RelayTrack {
|
||||
return &RelayTrack{
|
||||
created: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
type OnServerCloseListener func()
|
||||
|
||||
// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
|
||||
// and automatically reconnect to them in case disconnection.
|
||||
// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
|
||||
// different relay server, the manager will establish a new connection to the relay server. The connection with these
|
||||
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
|
||||
// unused relay connection and close it.
|
||||
type Manager struct {
|
||||
ctx context.Context
|
||||
peerID string
|
||||
running bool
|
||||
tokenStore *relayAuth.TokenStore
|
||||
serverPicker *ServerPicker
|
||||
|
||||
relayClient *Client
|
||||
// the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
|
||||
relayClientMu sync.RWMutex
|
||||
reconnectGuard *Guard
|
||||
|
||||
relayClients map[string]*RelayTrack
|
||||
relayClientsMutex sync.RWMutex
|
||||
|
||||
onDisconnectedListeners map[string]*list.List
|
||||
onReconnectedListenerFn func()
|
||||
listenerLock sync.Mutex
|
||||
}
|
||||
|
||||
// NewManager creates a new manager instance.
|
||||
// The serverURL address can be empty. In this case, the manager will not serve.
|
||||
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
|
||||
tokenStore := &relayAuth.TokenStore{}
|
||||
|
||||
m := &Manager{
|
||||
ctx: ctx,
|
||||
peerID: peerID,
|
||||
tokenStore: tokenStore,
|
||||
serverPicker: &ServerPicker{
|
||||
TokenStore: tokenStore,
|
||||
PeerID: peerID,
|
||||
},
|
||||
relayClients: make(map[string]*RelayTrack),
|
||||
onDisconnectedListeners: make(map[string]*list.List),
|
||||
}
|
||||
m.serverPicker.ServerURLs.Store(serverURLs)
|
||||
m.reconnectGuard = NewGuard(m.serverPicker)
|
||||
return m
|
||||
}
|
||||
|
||||
// Serve starts the manager, attempting to establish a connection with the relay server.
|
||||
// If the connection fails, it will keep trying to reconnect in the background.
|
||||
// Additionally, it starts a cleanup loop to remove unused relay connections.
|
||||
// The manager will automatically reconnect to the relay server in case of disconnection.
|
||||
func (m *Manager) Serve() error {
|
||||
if m.running {
|
||||
return fmt.Errorf("manager already serving")
|
||||
}
|
||||
m.running = true
|
||||
log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load())
|
||||
|
||||
client, err := m.serverPicker.PickServer(m.ctx)
|
||||
if err != nil {
|
||||
go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
|
||||
} else {
|
||||
m.storeClient(client)
|
||||
}
|
||||
|
||||
go m.listenGuardEvent(m.ctx)
|
||||
go m.startCleanupLoop()
|
||||
return err
|
||||
}
|
||||
|
||||
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
|
||||
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
|
||||
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
|
||||
func (m *Manager) OpenConn(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return nil, ErrRelayClientNotConnected
|
||||
}
|
||||
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
netConn net.Conn
|
||||
)
|
||||
if !foreign {
|
||||
log.Debugf("open peer connection via permanent server: %s", peerKey)
|
||||
netConn, err = m.relayClient.OpenConn(ctx, peerKey)
|
||||
} else {
|
||||
log.Debugf("open peer connection via foreign server: %s", serverAddress)
|
||||
netConn, err = m.openConnVia(ctx, serverAddress, peerKey)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return netConn, err
|
||||
}
|
||||
|
||||
// Ready returns true if the home Relay client is connected to the relay server.
|
||||
func (m *Manager) Ready() bool {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return false
|
||||
}
|
||||
return m.relayClient.Ready()
|
||||
}
|
||||
|
||||
func (m *Manager) SetOnReconnectedListener(f func()) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
|
||||
m.onReconnectedListenerFn = f
|
||||
}
|
||||
|
||||
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
|
||||
// closed.
|
||||
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return ErrRelayClientNotConnected
|
||||
}
|
||||
|
||||
foreign, err := m.isForeignServer(serverAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var listenerAddr string
|
||||
if foreign {
|
||||
listenerAddr = serverAddress
|
||||
} else {
|
||||
listenerAddr = m.relayClient.connectionURL
|
||||
}
|
||||
m.addListener(listenerAddr, onClosedListener)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
|
||||
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
|
||||
func (m *Manager) RelayInstanceAddress() (string, error) {
|
||||
m.relayClientMu.RLock()
|
||||
defer m.relayClientMu.RUnlock()
|
||||
|
||||
if m.relayClient == nil {
|
||||
return "", ErrRelayClientNotConnected
|
||||
}
|
||||
return m.relayClient.ServerInstanceURL()
|
||||
}
|
||||
|
||||
// ServerURLs returns the addresses of the relay servers.
|
||||
func (m *Manager) ServerURLs() []string {
|
||||
return m.serverPicker.ServerURLs.Load().([]string)
|
||||
}
|
||||
|
||||
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
|
||||
// Relay service.
|
||||
func (m *Manager) HasRelayAddress() bool {
|
||||
return len(m.serverPicker.ServerURLs.Load().([]string)) > 0
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateServerURLs(serverURLs []string) {
|
||||
log.Infof("update relay server URLs: %v", serverURLs)
|
||||
m.serverPicker.ServerURLs.Store(serverURLs)
|
||||
}
|
||||
|
||||
// UpdateToken updates the token in the token store.
|
||||
func (m *Manager) UpdateToken(token *relayAuth.Token) error {
|
||||
return m.tokenStore.UpdateToken(token)
|
||||
}
|
||||
|
||||
func (m *Manager) openConnVia(ctx context.Context, serverAddress, peerKey string) (net.Conn, error) {
|
||||
// check if already has a connection to the desired relay server
|
||||
m.relayClientsMutex.RLock()
|
||||
rt, ok := m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.RUnlock()
|
||||
defer rt.RUnlock()
|
||||
if rt.err != nil {
|
||||
return nil, rt.err
|
||||
}
|
||||
return rt.relayClient.OpenConn(ctx, peerKey)
|
||||
}
|
||||
m.relayClientsMutex.RUnlock()
|
||||
|
||||
// if not, establish a new connection but check it again (because changed the lock type) before starting the
|
||||
// connection
|
||||
m.relayClientsMutex.Lock()
|
||||
rt, ok = m.relayClients[serverAddress]
|
||||
if ok {
|
||||
rt.RLock()
|
||||
m.relayClientsMutex.Unlock()
|
||||
defer rt.RUnlock()
|
||||
if rt.err != nil {
|
||||
return nil, rt.err
|
||||
}
|
||||
return rt.relayClient.OpenConn(ctx, peerKey)
|
||||
}
|
||||
|
||||
// create a new relay client and store it in the relayClients map
|
||||
rt = NewRelayTrack()
|
||||
rt.Lock()
|
||||
m.relayClients[serverAddress] = rt
|
||||
m.relayClientsMutex.Unlock()
|
||||
|
||||
relayClient := NewClient(serverAddress, m.tokenStore, m.peerID)
|
||||
err := relayClient.Connect(m.ctx)
|
||||
if err != nil {
|
||||
rt.err = err
|
||||
rt.Unlock()
|
||||
m.relayClientsMutex.Lock()
|
||||
delete(m.relayClients, serverAddress)
|
||||
m.relayClientsMutex.Unlock()
|
||||
return nil, err
|
||||
}
|
||||
// if connection closed then delete the relay client from the list
|
||||
relayClient.SetOnDisconnectListener(m.onServerDisconnected)
|
||||
rt.relayClient = relayClient
|
||||
rt.Unlock()
|
||||
|
||||
conn, err := relayClient.OpenConn(ctx, peerKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (m *Manager) onServerConnected() {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
|
||||
if m.onReconnectedListenerFn == nil {
|
||||
return
|
||||
}
|
||||
go m.onReconnectedListenerFn()
|
||||
}
|
||||
|
||||
// onServerDisconnected start to reconnection for home server only
|
||||
func (m *Manager) onServerDisconnected(serverAddress string) {
|
||||
m.relayClientMu.Lock()
|
||||
if serverAddress == m.relayClient.connectionURL {
|
||||
go func(client *Client) {
|
||||
m.reconnectGuard.StartReconnectTrys(m.ctx, client)
|
||||
}(m.relayClient)
|
||||
}
|
||||
m.relayClientMu.Unlock()
|
||||
|
||||
m.notifyOnDisconnectListeners(serverAddress)
|
||||
}
|
||||
|
||||
func (m *Manager) listenGuardEvent(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-m.reconnectGuard.OnReconnected:
|
||||
m.onServerConnected()
|
||||
case rc := <-m.reconnectGuard.OnNewRelayClient:
|
||||
m.storeClient(rc)
|
||||
m.onServerConnected()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) storeClient(client *Client) {
|
||||
m.relayClientMu.Lock()
|
||||
defer m.relayClientMu.Unlock()
|
||||
|
||||
m.relayClient = client
|
||||
m.relayClient.SetOnDisconnectListener(m.onServerDisconnected)
|
||||
}
|
||||
|
||||
func (m *Manager) isForeignServer(address string) (bool, error) {
|
||||
rAddr, err := m.relayClient.ServerInstanceURL()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("relay client not connected")
|
||||
}
|
||||
return rAddr != address, nil
|
||||
}
|
||||
|
||||
func (m *Manager) startCleanupLoop() {
|
||||
ticker := time.NewTicker(relayCleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.cleanUpUnusedRelays()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) cleanUpUnusedRelays() {
|
||||
m.relayClientsMutex.Lock()
|
||||
defer m.relayClientsMutex.Unlock()
|
||||
|
||||
for addr, rt := range m.relayClients {
|
||||
rt.Lock()
|
||||
// if the connection failed to the server the relay client will be nil
|
||||
// but the instance will be kept in the relayClients until the next locking
|
||||
if rt.err != nil {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(rt.created) <= keepUnusedServerTime {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
if rt.relayClient.HasConns() {
|
||||
rt.Unlock()
|
||||
continue
|
||||
}
|
||||
rt.relayClient.SetOnDisconnectListener(nil)
|
||||
go func() {
|
||||
_ = rt.relayClient.Close()
|
||||
}()
|
||||
log.Debugf("clean up unused relay server connection: %s", addr)
|
||||
delete(m.relayClients, addr)
|
||||
rt.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
l, ok := m.onDisconnectedListeners[serverAddress]
|
||||
if !ok {
|
||||
l = list.New()
|
||||
}
|
||||
for e := l.Front(); e != nil; e = e.Next() {
|
||||
if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
|
||||
return
|
||||
}
|
||||
}
|
||||
l.PushBack(onClosedListener)
|
||||
m.onDisconnectedListeners[serverAddress] = l
|
||||
}
|
||||
|
||||
func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
|
||||
m.listenerLock.Lock()
|
||||
defer m.listenerLock.Unlock()
|
||||
|
||||
l, ok := m.onDisconnectedListeners[serverAddress]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for e := l.Front(); e != nil; e = e.Next() {
|
||||
go e.Value.(OnServerCloseListener)()
|
||||
}
|
||||
delete(m.onDisconnectedListeners, serverAddress)
|
||||
}
|
||||
469
shared/relay/client/manager_test.go
Normal file
469
shared/relay/client/manager_test.go
Normal file
@@ -0,0 +1,469 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/relay/auth/allow"
|
||||
"github.com/netbirdio/netbird/relay/server"
|
||||
)
|
||||
|
||||
func TestEmptyURL(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
mgr := NewManager(ctx, nil, "alice")
|
||||
err := mgr.Serve()
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignConn(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
lstCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
|
||||
srv1, err := server.NewServer(server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: lstCfg1.Address,
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(lstCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: srvCfg2.Address,
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
clientAlice := NewManager(mCtx, toURL(lstCfg1), "alice")
|
||||
if err := clientAlice.Serve(); err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
clientBob := NewManager(mCtx, toURL(srvCfg2), "bob")
|
||||
if err := clientBob.Serve(); err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
bobsSrvAddr, err := clientBob.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get relay address: %s", err)
|
||||
}
|
||||
connAliceToBob, err := clientAlice.OpenConn(ctx, bobsSrvAddr, "bob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
connBobToAlice, err := clientBob.OpenConn(ctx, bobsSrvAddr, "alice")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
payload := "hello bob, I am alice"
|
||||
_, err = connAliceToBob.Write([]byte(payload))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
buf := make([]byte, 65535)
|
||||
n, err := connBobToAlice.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
_, err = connBobToAlice.Write(buf[:n])
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write to channel: %s", err)
|
||||
}
|
||||
|
||||
n, err = connAliceToBob.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to read from channel: %s", err)
|
||||
}
|
||||
|
||||
if payload != string(buf[:n]) {
|
||||
t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeginConnClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv1.Listen(srvCfg1)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv1.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
mgrBob := NewManager(mCtx, toURL(srvCfg2), "bob")
|
||||
if err := mgrBob.Serve(); err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
mgr := NewManager(mCtx, toURL(srvCfg1), "alice")
|
||||
err = mgr.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
conn, err := mgr.OpenConn(ctx, toURL(srvCfg2)[0], "bob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to close connection: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignAutoClose(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
relayCleanupInterval = 1 * time.Second
|
||||
keepUnusedServerTime = 2 * time.Second
|
||||
|
||||
srvCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv1, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
t.Log("binding server 1.")
|
||||
if err := srv1.Listen(srvCfg1); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
t.Logf("closing server 1.")
|
||||
if err := srv1.Shutdown(ctx); err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 1. closed")
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
srvCfg2 := server.ListenerConfig{
|
||||
Address: "localhost:2234",
|
||||
}
|
||||
srv2, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan2 := make(chan error, 1)
|
||||
go func() {
|
||||
t.Log("binding server 2.")
|
||||
err := srv2.Listen(srvCfg2)
|
||||
if err != nil {
|
||||
errChan2 <- err
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
t.Logf("closing server 2.")
|
||||
err := srv2.Shutdown(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
t.Logf("server 2 closed.")
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan2); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
idAlice := "alice"
|
||||
t.Log("connect to server 1.")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
|
||||
err = mgr.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
// Set up a disconnect listener to track when foreign server disconnects
|
||||
foreignServerURL := toURL(srvCfg2)[0]
|
||||
disconnected := make(chan struct{})
|
||||
onDisconnect := func() {
|
||||
select {
|
||||
case disconnected <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
t.Log("open connection to another peer")
|
||||
if _, err = mgr.OpenConn(ctx, foreignServerURL, "anotherpeer"); err == nil {
|
||||
t.Fatalf("should have failed to open connection to another peer")
|
||||
}
|
||||
|
||||
// Add the disconnect listener after the connection attempt
|
||||
if err := mgr.AddCloseListener(foreignServerURL, onDisconnect); err != nil {
|
||||
t.Logf("failed to add close listener (expected if connection failed): %s", err)
|
||||
}
|
||||
|
||||
// Wait for cleanup to happen
|
||||
timeout := relayCleanupInterval + keepUnusedServerTime + 2*time.Second
|
||||
t.Logf("waiting for relay cleanup: %s", timeout)
|
||||
|
||||
select {
|
||||
case <-disconnected:
|
||||
t.Log("foreign relay connection cleaned up successfully")
|
||||
case <-time.After(timeout):
|
||||
t.Log("timeout waiting for cleanup - this might be expected if connection never established")
|
||||
}
|
||||
|
||||
t.Logf("closing manager")
|
||||
}
|
||||
|
||||
func TestAutoReconnect(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
srvCfg := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv, err := server.NewServer(serverCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Listen(srvCfg); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
err := srv.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
clientBob := NewManager(mCtx, toURL(srvCfg), "bob")
|
||||
err = clientBob.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
|
||||
err = clientAlice.Serve()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
ra, err := clientAlice.RelayInstanceAddress()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get relay address: %s", err)
|
||||
}
|
||||
conn, err := clientAlice.OpenConn(ctx, ra, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
t.Log("closing client relay connection")
|
||||
// todo figure out moc server
|
||||
_ = clientAlice.relayClient.relayConn.Close()
|
||||
t.Log("start test reading")
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Errorf("unexpected reading from closed connection")
|
||||
}
|
||||
|
||||
log.Infof("waiting for reconnection")
|
||||
time.Sleep(reconnectingTimeout + 1*time.Second)
|
||||
|
||||
log.Infof("reopent the connection")
|
||||
_, err = clientAlice.OpenConn(ctx, ra, "bob")
|
||||
if err != nil {
|
||||
t.Errorf("failed to open channel: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNotifierDoubleAdd(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
listenerCfg1 := server.ListenerConfig{
|
||||
Address: "localhost:1234",
|
||||
}
|
||||
srv, err := server.NewServer(server.Config{
|
||||
Meter: otel.Meter(""),
|
||||
ExposedAddress: listenerCfg1.Address,
|
||||
TLSSupport: false,
|
||||
AuthValidator: &allow.Auth{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create server: %s", err)
|
||||
}
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
if err := srv.Listen(listenerCfg1); err != nil {
|
||||
errChan <- err
|
||||
}
|
||||
}()
|
||||
|
||||
defer func() {
|
||||
if err := srv.Shutdown(ctx); err != nil {
|
||||
t.Errorf("failed to close server: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := waitForServerToStart(errChan); err != nil {
|
||||
t.Fatalf("failed to start server: %s", err)
|
||||
}
|
||||
|
||||
log.Debugf("connect by alice")
|
||||
mCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
clientBob := NewManager(mCtx, toURL(listenerCfg1), "bob")
|
||||
if err = clientBob.Serve(); err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
clientAlice := NewManager(mCtx, toURL(listenerCfg1), "alice")
|
||||
if err = clientAlice.Serve(); err != nil {
|
||||
t.Fatalf("failed to serve manager: %s", err)
|
||||
}
|
||||
|
||||
conn1, err := clientAlice.OpenConn(ctx, clientAlice.ServerURLs()[0], "bob")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind channel: %s", err)
|
||||
}
|
||||
|
||||
fnCloseListener := OnServerCloseListener(func() {
|
||||
log.Infof("close listener")
|
||||
})
|
||||
|
||||
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add close listener: %s", err)
|
||||
}
|
||||
|
||||
err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add close listener: %s", err)
|
||||
}
|
||||
|
||||
err = conn1.Close()
|
||||
if err != nil {
|
||||
t.Errorf("failed to close connection: %s", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func toURL(address server.ListenerConfig) []string {
|
||||
return []string{"rel://" + address.Address}
|
||||
}
|
||||
191
shared/relay/client/peer_subscription.go
Normal file
191
shared/relay/client/peer_subscription.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
)
|
||||
|
||||
const (
|
||||
OpenConnectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type relayedConnWriter interface {
|
||||
Write(p []byte) (n int, err error)
|
||||
}
|
||||
|
||||
// PeersStateSubscription manages subscriptions to peer state changes (online/offline)
|
||||
// over a relay connection. It allows tracking peers' availability and handling offline
|
||||
// events via a callback. We get online notification from the server only once.
|
||||
type PeersStateSubscription struct {
|
||||
log *log.Entry
|
||||
relayConn relayedConnWriter
|
||||
offlineCallback func(peerIDs []messages.PeerID)
|
||||
|
||||
listenForOfflinePeers map[messages.PeerID]struct{}
|
||||
waitingPeers map[messages.PeerID]chan struct{}
|
||||
mu sync.Mutex // Mutex to protect access to waitingPeers and listenForOfflinePeers
|
||||
}
|
||||
|
||||
func NewPeersStateSubscription(log *log.Entry, relayConn relayedConnWriter, offlineCallback func(peerIDs []messages.PeerID)) *PeersStateSubscription {
|
||||
return &PeersStateSubscription{
|
||||
log: log,
|
||||
relayConn: relayConn,
|
||||
offlineCallback: offlineCallback,
|
||||
listenForOfflinePeers: make(map[messages.PeerID]struct{}),
|
||||
waitingPeers: make(map[messages.PeerID]chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// OnPeersOnline should be called when a notification is received that certain peers have come online.
|
||||
// It checks if any of the peers are being waited on and signals their availability.
|
||||
func (s *PeersStateSubscription) OnPeersOnline(peersID []messages.PeerID) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, peerID := range peersID {
|
||||
waitCh, ok := s.waitingPeers[peerID]
|
||||
if !ok {
|
||||
// If meanwhile the peer was unsubscribed, we don't need to signal it
|
||||
continue
|
||||
}
|
||||
|
||||
waitCh <- struct{}{}
|
||||
delete(s.waitingPeers, peerID)
|
||||
close(waitCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PeersStateSubscription) OnPeersWentOffline(peersID []messages.PeerID) {
|
||||
s.mu.Lock()
|
||||
relevantPeers := make([]messages.PeerID, 0, len(peersID))
|
||||
for _, peerID := range peersID {
|
||||
if _, ok := s.listenForOfflinePeers[peerID]; ok {
|
||||
relevantPeers = append(relevantPeers, peerID)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if len(relevantPeers) > 0 {
|
||||
s.offlineCallback(relevantPeers)
|
||||
}
|
||||
}
|
||||
|
||||
// WaitToBeOnlineAndSubscribe waits for a specific peer to come online and subscribes to its state changes.
|
||||
func (s *PeersStateSubscription) WaitToBeOnlineAndSubscribe(ctx context.Context, peerID messages.PeerID) error {
|
||||
// Check if already waiting for this peer
|
||||
s.mu.Lock()
|
||||
if _, exists := s.waitingPeers[peerID]; exists {
|
||||
s.mu.Unlock()
|
||||
return errors.New("already waiting for peer to come online")
|
||||
}
|
||||
|
||||
// Create a channel to wait for the peer to come online
|
||||
waitCh := make(chan struct{}, 1)
|
||||
s.waitingPeers[peerID] = waitCh
|
||||
s.listenForOfflinePeers[peerID] = struct{}{}
|
||||
s.mu.Unlock()
|
||||
|
||||
if err := s.subscribeStateChange(peerID); err != nil {
|
||||
s.log.Errorf("failed to subscribe to peer state: %s", err)
|
||||
s.mu.Lock()
|
||||
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
|
||||
close(waitCh)
|
||||
delete(s.waitingPeers, peerID)
|
||||
delete(s.listenForOfflinePeers, peerID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for peer to come online or context to be cancelled
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, OpenConnectionTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case _, ok := <-waitCh:
|
||||
if !ok {
|
||||
return fmt.Errorf("wait for peer to come online has been cancelled")
|
||||
}
|
||||
|
||||
s.log.Debugf("peer %s is now online", peerID)
|
||||
return nil
|
||||
case <-timeoutCtx.Done():
|
||||
s.log.Debugf("context timed out while waiting for peer %s to come online", peerID)
|
||||
if err := s.unsubscribeStateChange([]messages.PeerID{peerID}); err != nil {
|
||||
s.log.Errorf("failed to unsubscribe from peer state: %s", err)
|
||||
}
|
||||
s.mu.Lock()
|
||||
if ch, exists := s.waitingPeers[peerID]; exists && ch == waitCh {
|
||||
close(waitCh)
|
||||
delete(s.waitingPeers, peerID)
|
||||
delete(s.listenForOfflinePeers, peerID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return timeoutCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PeersStateSubscription) UnsubscribeStateChange(peerIDs []messages.PeerID) error {
|
||||
msgErr := s.unsubscribeStateChange(peerIDs)
|
||||
|
||||
s.mu.Lock()
|
||||
for _, peerID := range peerIDs {
|
||||
if wch, ok := s.waitingPeers[peerID]; ok {
|
||||
close(wch)
|
||||
delete(s.waitingPeers, peerID)
|
||||
}
|
||||
|
||||
delete(s.listenForOfflinePeers, peerID)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
return msgErr
|
||||
}
|
||||
|
||||
func (s *PeersStateSubscription) Cleanup() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
for _, waitCh := range s.waitingPeers {
|
||||
close(waitCh)
|
||||
}
|
||||
|
||||
s.waitingPeers = make(map[messages.PeerID]chan struct{})
|
||||
s.listenForOfflinePeers = make(map[messages.PeerID]struct{})
|
||||
}
|
||||
|
||||
func (s *PeersStateSubscription) subscribeStateChange(peerID messages.PeerID) error {
|
||||
msgs, err := messages.MarshalSubPeerStateMsg([]messages.PeerID{peerID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, msg := range msgs {
|
||||
if _, err := s.relayConn.Write(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PeersStateSubscription) unsubscribeStateChange(peerIDs []messages.PeerID) error {
|
||||
msgs, err := messages.MarshalUnsubPeerStateMsg(peerIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var connWriteErr error
|
||||
for _, msg := range msgs {
|
||||
if _, err := s.relayConn.Write(msg); err != nil {
|
||||
connWriteErr = err
|
||||
}
|
||||
}
|
||||
return connWriteErr
|
||||
}
|
||||
99
shared/relay/client/peer_subscription_test.go
Normal file
99
shared/relay/client/peer_subscription_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/relay/messages"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type mockRelayedConn struct {
|
||||
}
|
||||
|
||||
func (m *mockRelayedConn) Write(p []byte) (n int, err error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func TestWaitToBeOnlineAndSubscribe_Success(t *testing.T) {
|
||||
peerID := messages.HashID("peer1")
|
||||
mockConn := &mockRelayedConn{}
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(&bytes.Buffer{}) // discard log output
|
||||
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Launch wait in background
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
sub.OnPeersOnline([]messages.PeerID{peerID})
|
||||
}()
|
||||
|
||||
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWaitToBeOnlineAndSubscribe_Timeout(t *testing.T) {
|
||||
peerID := messages.HashID("peer2")
|
||||
mockConn := &mockRelayedConn{}
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(&bytes.Buffer{})
|
||||
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
}
|
||||
|
||||
func TestWaitToBeOnlineAndSubscribe_Duplicate(t *testing.T) {
|
||||
peerID := messages.HashID("peer3")
|
||||
mockConn := &mockRelayedConn{}
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(&bytes.Buffer{})
|
||||
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
go func() {
|
||||
_ = sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
|
||||
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
err := sub.WaitToBeOnlineAndSubscribe(ctx, peerID)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already waiting")
|
||||
}
|
||||
|
||||
func TestUnsubscribeStateChange(t *testing.T) {
|
||||
peerID := messages.HashID("peer4")
|
||||
mockConn := &mockRelayedConn{}
|
||||
logger := logrus.New()
|
||||
logger.SetOutput(&bytes.Buffer{})
|
||||
sub := NewPeersStateSubscription(logrus.NewEntry(logger), mockConn, nil)
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
_ = sub.WaitToBeOnlineAndSubscribe(context.Background(), peerID)
|
||||
close(doneChan)
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
err := sub.UnsubscribeStateChange([]messages.PeerID{peerID})
|
||||
assert.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-doneChan:
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// Expected timeout, meaning the subscription was successfully unsubscribed
|
||||
t.Errorf("timeout")
|
||||
}
|
||||
}
|
||||
104
shared/relay/client/picker.go
Normal file
104
shared/relay/client/picker.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
)
|
||||
|
||||
const (
|
||||
maxConcurrentServers = 7
|
||||
)
|
||||
|
||||
var (
|
||||
connectionTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type connResult struct {
|
||||
RelayClient *Client
|
||||
Url string
|
||||
Err error
|
||||
}
|
||||
|
||||
type ServerPicker struct {
|
||||
TokenStore *auth.TokenStore
|
||||
ServerURLs atomic.Value
|
||||
PeerID string
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
totalServers := len(sp.ServerURLs.Load().([]string))
|
||||
|
||||
connResultChan := make(chan connResult, totalServers)
|
||||
successChan := make(chan connResult, 1)
|
||||
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
|
||||
|
||||
log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
|
||||
for _, url := range sp.ServerURLs.Load().([]string) {
|
||||
// todo check if we have a successful connection so we do not need to connect to other servers
|
||||
concurrentLimiter <- struct{}{}
|
||||
go func(url string) {
|
||||
defer func() {
|
||||
<-concurrentLimiter
|
||||
}()
|
||||
sp.startConnection(parentCtx, connResultChan, url)
|
||||
}(url)
|
||||
}
|
||||
|
||||
go sp.processConnResults(connResultChan, successChan)
|
||||
|
||||
select {
|
||||
case cr, ok := <-successChan:
|
||||
if !ok {
|
||||
return nil, errors.New("failed to connect to any relay server: all attempts failed")
|
||||
}
|
||||
log.Infof("chosen home Relay server: %s", cr.Url)
|
||||
return cr.RelayClient, nil
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
|
||||
log.Infof("try to connecting to relay server: %s", url)
|
||||
relayClient := NewClient(url, sp.TokenStore, sp.PeerID)
|
||||
err := relayClient.Connect(ctx)
|
||||
resultChan <- connResult{
|
||||
RelayClient: relayClient,
|
||||
Url: url,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
|
||||
var hasSuccess bool
|
||||
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
|
||||
cr := <-resultChan
|
||||
if cr.Err != nil {
|
||||
log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
|
||||
continue
|
||||
}
|
||||
log.Infof("connected to Relay server: %s", cr.Url)
|
||||
|
||||
if hasSuccess {
|
||||
log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
|
||||
if err := cr.RelayClient.Close(); err != nil {
|
||||
log.Errorf("failed to close connection to %s: %v", cr.Url, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
hasSuccess = true
|
||||
successChan <- cr
|
||||
}
|
||||
close(successChan)
|
||||
}
|
||||
34
shared/relay/client/picker_test.go
Normal file
34
shared/relay/client/picker_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestServerPicker_UnavailableServers(t *testing.T) {
|
||||
connectionTimeout = 5 * time.Second
|
||||
|
||||
sp := ServerPicker{
|
||||
TokenStore: nil,
|
||||
PeerID: "test",
|
||||
}
|
||||
sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_, err := sp.PickServer(ctx)
|
||||
if err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
t.Errorf("PickServer() took too long to complete")
|
||||
}
|
||||
}
|
||||
77
shared/signal/client/client.go
Normal file
77
shared/signal/client/client.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
|
||||
|
||||
// Status is the status of the client
|
||||
type Status string
|
||||
|
||||
const StreamConnected Status = "Connected"
|
||||
const StreamDisconnected Status = "Disconnected"
|
||||
|
||||
const (
|
||||
// DirectCheck indicates support to direct mode checks
|
||||
DirectCheck uint32 = 1
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
io.Closer
|
||||
StreamConnected() bool
|
||||
GetStatus() Status
|
||||
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
|
||||
Ready() bool
|
||||
IsHealthy() bool
|
||||
WaitStreamConnected()
|
||||
SendToStream(msg *proto.EncryptedMessage) error
|
||||
Send(msg *proto.Message) error
|
||||
SetOnReconnectedListener(func())
|
||||
}
|
||||
|
||||
// UnMarshalCredential parses the credentials from the message and returns a Credential instance
|
||||
func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
|
||||
|
||||
credential := strings.Split(msg.GetBody().GetPayload(), ":")
|
||||
if len(credential) != 2 {
|
||||
return nil, fmt.Errorf("error parsing message body %s", msg.Body)
|
||||
}
|
||||
return &Credential{
|
||||
UFrag: credential[0],
|
||||
Pwd: credential[1],
|
||||
}, nil
|
||||
}
|
||||
|
||||
// MarshalCredential marshal a Credential instance and returns a Message object
|
||||
func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) {
|
||||
return &proto.Message{
|
||||
Key: myKey.PublicKey().String(),
|
||||
RemoteKey: remoteKey,
|
||||
Body: &proto.Body{
|
||||
Type: t,
|
||||
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
|
||||
WgListenPort: uint32(myPort),
|
||||
NetBirdVersion: version.NetbirdVersion(),
|
||||
RosenpassConfig: &proto.RosenpassConfig{
|
||||
RosenpassPubKey: rosenpassPubKey,
|
||||
RosenpassServerAddr: rosenpassAddr,
|
||||
},
|
||||
RelayServerAddress: relaySrvAddress,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Credential is an instance of a GrpcClient's Credential
|
||||
type Credential struct {
|
||||
UFrag string
|
||||
Pwd string
|
||||
}
|
||||
13
shared/signal/client/client_suite_test.go
Normal file
13
shared/signal/client/client_suite_test.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSignal(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Signal Suite")
|
||||
}
|
||||
230
shared/signal/client/client_test.go
Normal file
230
shared/signal/client/client_test.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
sigProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/server"
|
||||
)
|
||||
|
||||
var _ = Describe("GrpcClient", func() {
|
||||
|
||||
var (
|
||||
addr string
|
||||
listener net.Listener
|
||||
server *grpc.Server
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
server, listener = startSignal()
|
||||
addr = listener.Addr().String()
|
||||
|
||||
})
|
||||
|
||||
AfterEach(func() {
|
||||
server.Stop()
|
||||
listener.Close()
|
||||
})
|
||||
|
||||
Describe("Exchanging messages", func() {
|
||||
Context("between connected peers", func() {
|
||||
It("should be successful", func() {
|
||||
|
||||
var msgReceived sync.WaitGroup
|
||||
msgReceived.Add(2)
|
||||
|
||||
var payloadReceivedOnA string
|
||||
var payloadReceivedOnB string
|
||||
var featuresSupportedReceivedOnA []uint32
|
||||
var featuresSupportedReceivedOnB []uint32
|
||||
|
||||
// connect PeerA to Signal
|
||||
keyA, _ := wgtypes.GenerateKey()
|
||||
clientA := createSignalClient(addr, keyA)
|
||||
go func() {
|
||||
err := clientA.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||
payloadReceivedOnA = msg.GetBody().GetPayload()
|
||||
featuresSupportedReceivedOnA = msg.GetBody().GetFeaturesSupported()
|
||||
msgReceived.Done()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
clientA.WaitStreamConnected()
|
||||
|
||||
// connect PeerB to Signal
|
||||
keyB, _ := wgtypes.GenerateKey()
|
||||
clientB := createSignalClient(addr, keyB)
|
||||
|
||||
go func() {
|
||||
err := clientB.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||
payloadReceivedOnB = msg.GetBody().GetPayload()
|
||||
featuresSupportedReceivedOnB = msg.GetBody().GetFeaturesSupported()
|
||||
err := clientB.Send(&sigProto.Message{
|
||||
Key: keyB.PublicKey().String(),
|
||||
RemoteKey: keyA.PublicKey().String(),
|
||||
Body: &sigProto.Body{Payload: "pong"},
|
||||
})
|
||||
if err != nil {
|
||||
Fail("failed sending a message to PeerA")
|
||||
}
|
||||
msgReceived.Done()
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
clientB.WaitStreamConnected()
|
||||
|
||||
// PeerA initiates ping-pong
|
||||
err := clientA.Send(&sigProto.Message{
|
||||
Key: keyA.PublicKey().String(),
|
||||
RemoteKey: keyB.PublicKey().String(),
|
||||
Body: &sigProto.Body{Payload: "ping", FeaturesSupported: []uint32{DirectCheck}},
|
||||
})
|
||||
if err != nil {
|
||||
Fail("failed sending a message to PeerB")
|
||||
}
|
||||
|
||||
if waitTimeout(&msgReceived, 3*time.Second) {
|
||||
Fail("test timed out on waiting for peers to exchange messages")
|
||||
}
|
||||
|
||||
Expect(payloadReceivedOnA).To(BeEquivalentTo("pong"))
|
||||
Expect(payloadReceivedOnB).To(BeEquivalentTo("ping"))
|
||||
Expect(featuresSupportedReceivedOnA).To(BeNil())
|
||||
Expect(featuresSupportedReceivedOnB).To(ContainElements([]uint32{DirectCheck}))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
Describe("Connecting to the Signal stream channel", func() {
|
||||
Context("with a signal client", func() {
|
||||
It("should be successful", func() {
|
||||
|
||||
key, _ := wgtypes.GenerateKey()
|
||||
client := createSignalClient(addr, key)
|
||||
go func() {
|
||||
err := client.Receive(context.Background(), func(msg *sigProto.Message) error {
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}()
|
||||
client.WaitStreamConnected()
|
||||
Expect(client).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with a raw client and no Id header", func() {
|
||||
It("should fail", func() {
|
||||
|
||||
client := createRawSignalClient(addr)
|
||||
stream, err := client.ConnectStream(context.Background())
|
||||
if err != nil {
|
||||
Fail("error connecting to stream")
|
||||
}
|
||||
|
||||
_, err = stream.Recv()
|
||||
|
||||
Expect(stream).NotTo(BeNil())
|
||||
Expect(err).NotTo(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
Context("with a raw client and with an Id header", func() {
|
||||
It("should be successful", func() {
|
||||
|
||||
md := metadata.New(map[string]string{sigProto.HeaderId: "peer"})
|
||||
ctx := metadata.NewOutgoingContext(context.Background(), md)
|
||||
|
||||
client := createRawSignalClient(addr)
|
||||
stream, err := client.ConnectStream(ctx)
|
||||
|
||||
Expect(stream).NotTo(BeNil())
|
||||
Expect(err).To(BeNil())
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
func createSignalClient(addr string, key wgtypes.Key) *GrpcClient {
|
||||
var sigTLSEnabled = false
|
||||
client, err := NewClient(context.Background(), addr, key, sigTLSEnabled)
|
||||
if err != nil {
|
||||
Fail("failed creating signal client")
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func createRawSignalClient(addr string) sigProto.SignalExchangeClient {
|
||||
ctx := context.Background()
|
||||
conn, err := grpc.DialContext(ctx, addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 3 * time.Second,
|
||||
Timeout: 2 * time.Second,
|
||||
}))
|
||||
if err != nil {
|
||||
Fail("failed creating raw signal client")
|
||||
}
|
||||
|
||||
return sigProto.NewSignalExchangeClient(conn)
|
||||
}
|
||||
|
||||
func startSignal() (*grpc.Server, net.Listener) {
|
||||
lis, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s := grpc.NewServer()
|
||||
srv, err := server.NewServer(context.Background(), otel.Meter(""))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sigProto.RegisterSignalExchangeServer(s, srv)
|
||||
go func() {
|
||||
if err := s.Serve(lis); err != nil {
|
||||
log.Fatalf("failed to serve: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s, lis
|
||||
}
|
||||
|
||||
// waitTimeout waits for the waitgroup for the specified max timeout.
|
||||
// Returns true if waiting timed out.
|
||||
func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||
c := make(chan struct{})
|
||||
go func() {
|
||||
defer close(c)
|
||||
wg.Wait()
|
||||
}()
|
||||
select {
|
||||
case <-c:
|
||||
return false // completed normally
|
||||
case <-time.After(timeout):
|
||||
return true // timed out
|
||||
}
|
||||
}
|
||||
10
shared/signal/client/go.sum
Normal file
10
shared/signal/client/go.sum
Normal file
@@ -0,0 +1,10 @@
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
437
shared/signal/client/grpc.go
Normal file
437
shared/signal/client/grpc.go
Normal file
@@ -0,0 +1,437 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/encryption"
|
||||
"github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
||||
)
|
||||
|
||||
// ConnStateNotifier is a wrapper interface of the status recorder
|
||||
type ConnStateNotifier interface {
|
||||
MarkSignalDisconnected(error)
|
||||
MarkSignalConnected()
|
||||
}
|
||||
|
||||
// GrpcClient Wraps the Signal Exchange Service gRpc client
|
||||
type GrpcClient struct {
|
||||
key wgtypes.Key
|
||||
realClient proto.SignalExchangeClient
|
||||
signalConn *grpc.ClientConn
|
||||
ctx context.Context
|
||||
stream proto.SignalExchange_ConnectStreamClient
|
||||
// connectedCh used to notify goroutines waiting for the connection to the Signal stream
|
||||
connectedCh chan struct{}
|
||||
mux sync.Mutex
|
||||
// StreamConnected indicates whether this client is StreamConnected to the Signal stream
|
||||
status Status
|
||||
|
||||
connStateCallback ConnStateNotifier
|
||||
connStateCallbackLock sync.RWMutex
|
||||
|
||||
onReconnectedListenerFn func()
|
||||
}
|
||||
|
||||
func (c *GrpcClient) StreamConnected() bool {
|
||||
return c.status == StreamConnected
|
||||
}
|
||||
|
||||
func (c *GrpcClient) GetStatus() Status {
|
||||
return c.status
|
||||
}
|
||||
|
||||
// Close Closes underlying connections to the Signal Exchange
|
||||
func (c *GrpcClient) Close() error {
|
||||
return c.signalConn.Close()
|
||||
}
|
||||
|
||||
// NewClient creates a new Signal client
|
||||
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
|
||||
var conn *grpc.ClientConn
|
||||
|
||||
operation := func() error {
|
||||
var err error
|
||||
conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
|
||||
if err != nil {
|
||||
log.Printf("createConnection error: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, nbgrpc.Backoff(ctx))
|
||||
if err != nil {
|
||||
log.Errorf("failed to connect to the signalling server: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("connected to Signal Service: %v", conn.Target())
|
||||
|
||||
return &GrpcClient{
|
||||
realClient: proto.NewSignalExchangeClient(conn),
|
||||
ctx: ctx,
|
||||
signalConn: conn,
|
||||
key: key,
|
||||
mux: sync.Mutex{},
|
||||
status: StreamDisconnected,
|
||||
connStateCallbackLock: sync.RWMutex{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetConnStateListener set the ConnStateNotifier
|
||||
func (c *GrpcClient) SetConnStateListener(notifier ConnStateNotifier) {
|
||||
c.connStateCallbackLock.Lock()
|
||||
defer c.connStateCallbackLock.Unlock()
|
||||
c.connStateCallback = notifier
|
||||
}
|
||||
|
||||
// defaultBackoff is a basic backoff mechanism for general issues
|
||||
func defaultBackoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(&backoff.ExponentialBackOff{
|
||||
InitialInterval: 800 * time.Millisecond,
|
||||
RandomizationFactor: 1,
|
||||
Multiplier: 1.7,
|
||||
MaxInterval: 10 * time.Second,
|
||||
MaxElapsedTime: 3 * 30 * 24 * time.Hour, // 3 months
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
// Receive Connects to the Signal Exchange message stream and starts receiving messages.
|
||||
// The messages will be handled by msgHandler function provided.
|
||||
// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
|
||||
// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
|
||||
func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error {
|
||||
|
||||
var backOff = defaultBackoff(ctx)
|
||||
|
||||
operation := func() error {
|
||||
|
||||
c.notifyStreamDisconnected()
|
||||
|
||||
log.Debugf("signal connection state %v", c.signalConn.GetState())
|
||||
connState := c.signalConn.GetState()
|
||||
if connState == connectivity.Shutdown {
|
||||
return backoff.Permanent(fmt.Errorf("connection to signal has been shut down"))
|
||||
} else if !(connState == connectivity.Ready || connState == connectivity.Idle) {
|
||||
c.signalConn.WaitForStateChange(ctx, connState)
|
||||
return fmt.Errorf("connection to signal is not ready and in %s state", connState)
|
||||
}
|
||||
|
||||
// connect to Signal stream identifying ourselves with a public WireGuard key
|
||||
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
|
||||
ctx, cancelStream := context.WithCancel(ctx)
|
||||
defer cancelStream()
|
||||
stream, err := c.connect(ctx, c.key.PublicKey().String())
|
||||
if err != nil {
|
||||
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.notifyStreamConnected()
|
||||
|
||||
log.Infof("connected to the Signal Service stream")
|
||||
c.notifyConnected()
|
||||
// start receiving messages from the Signal stream (from other peers through signal)
|
||||
err = c.receive(stream, msgHandler)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
log.Debugf("signal connection context has been canceled, this usually indicates shutdown")
|
||||
return nil
|
||||
}
|
||||
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
|
||||
// reset times and next try will start with a long delay
|
||||
backOff.Reset()
|
||||
c.notifyDisconnected(err)
|
||||
log.Warnf("disconnected from the Signal service but will retry silently. Reason: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backOff)
|
||||
if err != nil {
|
||||
log.Errorf("exiting the Signal service connection retry loop due to the unrecoverable error: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (c *GrpcClient) notifyStreamDisconnected() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
c.status = StreamDisconnected
|
||||
}
|
||||
|
||||
func (c *GrpcClient) notifyStreamConnected() {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.status = StreamConnected
|
||||
if c.connectedCh != nil {
|
||||
// there are goroutines waiting on this channel -> release them
|
||||
close(c.connectedCh)
|
||||
c.connectedCh = nil
|
||||
}
|
||||
|
||||
if c.onReconnectedListenerFn != nil {
|
||||
c.onReconnectedListenerFn()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
if c.connectedCh == nil {
|
||||
c.connectedCh = make(chan struct{})
|
||||
}
|
||||
return c.connectedCh
|
||||
}
|
||||
|
||||
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
|
||||
c.stream = nil
|
||||
|
||||
// add key fingerprint to the request header to be identified on the server side
|
||||
md := metadata.New(map[string]string{proto.HeaderId: key})
|
||||
metaCtx := metadata.NewOutgoingContext(ctx, md)
|
||||
stream, err := c.realClient.ConnectStream(metaCtx, grpc.WaitForReady(true))
|
||||
c.stream = stream
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// blocks
|
||||
header, err := c.stream.Header()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
registered := header.Get(proto.HeaderRegistered)
|
||||
if len(registered) == 0 {
|
||||
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// Ready indicates whether the client is okay and Ready to be used
|
||||
// for now it just checks whether gRPC connection to the service is in state Ready
|
||||
func (c *GrpcClient) Ready() bool {
|
||||
return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle
|
||||
}
|
||||
|
||||
// IsHealthy probes the gRPC connection and returns false on errors
|
||||
func (c *GrpcClient) IsHealthy() bool {
|
||||
switch c.signalConn.GetState() {
|
||||
case connectivity.TransientFailure:
|
||||
return false
|
||||
case connectivity.Connecting:
|
||||
return true
|
||||
case connectivity.Shutdown:
|
||||
return true
|
||||
case connectivity.Idle:
|
||||
case connectivity.Ready:
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second)
|
||||
defer cancel()
|
||||
_, err := c.realClient.Send(ctx, &proto.EncryptedMessage{
|
||||
Key: c.key.PublicKey().String(),
|
||||
RemoteKey: "dummy",
|
||||
Body: nil,
|
||||
})
|
||||
if err != nil {
|
||||
c.notifyDisconnected(err)
|
||||
log.Warnf("health check returned: %s", err)
|
||||
return false
|
||||
}
|
||||
c.notifyConnected()
|
||||
return true
|
||||
}
|
||||
|
||||
// WaitStreamConnected waits until the client is connected to the Signal stream
|
||||
func (c *GrpcClient) WaitStreamConnected() {
|
||||
|
||||
if c.status == StreamConnected {
|
||||
return
|
||||
}
|
||||
|
||||
ch := c.getStreamStatusChan()
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
case <-ch:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClient) SetOnReconnectedListener(fn func()) {
|
||||
c.mux.Lock()
|
||||
defer c.mux.Unlock()
|
||||
|
||||
c.onReconnectedListenerFn = fn
|
||||
}
|
||||
|
||||
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
|
||||
// The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
|
||||
// GrpcClient.connWg can be used to wait
|
||||
func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
if c.stream == nil {
|
||||
return fmt.Errorf("connection to the Signal Exchange has not been established yet. Please call GrpcClient.Receive before sending messages")
|
||||
}
|
||||
|
||||
err := c.stream.Send(msg)
|
||||
if err != nil {
|
||||
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *GrpcClient) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) {
|
||||
remoteKey, err := wgtypes.ParseKey(msg.GetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body := &proto.Body{}
|
||||
err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.Message{
|
||||
Key: msg.Key,
|
||||
RemoteKey: msg.RemoteKey,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key
|
||||
func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) {
|
||||
|
||||
remoteKey, err := wgtypes.ParseKey(msg.RemoteKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
Key: msg.GetKey(),
|
||||
RemoteKey: msg.GetRemoteKey(),
|
||||
Body: encryptedBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Send sends a message to the remote Peer through the Signal Exchange.
|
||||
func (c *GrpcClient) Send(msg *proto.Message) error {
|
||||
|
||||
if !c.Ready() {
|
||||
return fmt.Errorf("no connection to signal")
|
||||
}
|
||||
|
||||
encryptedMessage, err := c.encryptMessage(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
attemptTimeout := client.ConnectTimeout
|
||||
|
||||
for attempt := 0; attempt < 4; attempt++ {
|
||||
if attempt > 1 {
|
||||
attemptTimeout = time.Duration(attempt) * 5 * time.Second
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(c.ctx, attemptTimeout)
|
||||
|
||||
_, err = c.realClient.Send(ctx, encryptedMessage)
|
||||
|
||||
cancel()
|
||||
|
||||
if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
return err
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// receive receives messages from other peers coming through the Signal Exchange
|
||||
func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
|
||||
msgHandler func(msg *proto.Message) error) error {
|
||||
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
switch s, ok := status.FromError(err); {
|
||||
case ok && s.Code() == codes.Canceled:
|
||||
log.Debugf("stream canceled (usually indicates shutdown)")
|
||||
return err
|
||||
case s.Code() == codes.Unavailable:
|
||||
log.Debugf("Signal Service is unavailable")
|
||||
return err
|
||||
case err == io.EOF:
|
||||
log.Debugf("Signal Service stream closed by server")
|
||||
return err
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
log.Tracef("received a new message from Peer [fingerprint: %s]", msg.Key)
|
||||
|
||||
decryptedMessage, err := c.decryptMessage(msg)
|
||||
if err != nil {
|
||||
log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||
}
|
||||
|
||||
err = msgHandler(decryptedMessage)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error())
|
||||
// todo send something??
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GrpcClient) notifyDisconnected(err error) {
|
||||
c.connStateCallbackLock.RLock()
|
||||
defer c.connStateCallbackLock.RUnlock()
|
||||
|
||||
if c.connStateCallback == nil {
|
||||
return
|
||||
}
|
||||
c.connStateCallback.MarkSignalDisconnected(err)
|
||||
}
|
||||
|
||||
func (c *GrpcClient) notifyConnected() {
|
||||
c.connStateCallbackLock.RLock()
|
||||
defer c.connStateCallbackLock.RUnlock()
|
||||
|
||||
if c.connStateCallback == nil {
|
||||
return
|
||||
}
|
||||
c.connStateCallback.MarkSignalConnected()
|
||||
}
|
||||
84
shared/signal/client/mock.go
Normal file
84
shared/signal/client/mock.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
)
|
||||
|
||||
type MockClient struct {
|
||||
CloseFunc func() error
|
||||
GetStatusFunc func() Status
|
||||
StreamConnectedFunc func() bool
|
||||
ReadyFunc func() bool
|
||||
WaitStreamConnectedFunc func()
|
||||
ReceiveFunc func(ctx context.Context, msgHandler func(msg *proto.Message) error) error
|
||||
SendToStreamFunc func(msg *proto.EncryptedMessage) error
|
||||
SendFunc func(msg *proto.Message) error
|
||||
SetOnReconnectedListenerFunc func(f func())
|
||||
}
|
||||
|
||||
// SetOnReconnectedListener sets the function to be called when the client reconnects.
|
||||
func (sm *MockClient) SetOnReconnectedListener(_ func()) {
|
||||
// Do nothing
|
||||
}
|
||||
|
||||
func (sm *MockClient) IsHealthy() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (sm *MockClient) Close() error {
|
||||
if sm.CloseFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.CloseFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) GetStatus() Status {
|
||||
if sm.GetStatusFunc == nil {
|
||||
return ""
|
||||
}
|
||||
return sm.GetStatusFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) StreamConnected() bool {
|
||||
if sm.StreamConnectedFunc == nil {
|
||||
return false
|
||||
}
|
||||
return sm.StreamConnectedFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) Ready() bool {
|
||||
if sm.ReadyFunc == nil {
|
||||
return false
|
||||
}
|
||||
return sm.ReadyFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) WaitStreamConnected() {
|
||||
if sm.WaitStreamConnectedFunc == nil {
|
||||
return
|
||||
}
|
||||
sm.WaitStreamConnectedFunc()
|
||||
}
|
||||
|
||||
func (sm *MockClient) Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error {
|
||||
if sm.ReceiveFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.ReceiveFunc(ctx, msgHandler)
|
||||
}
|
||||
|
||||
func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error {
|
||||
if sm.SendToStreamFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.SendToStreamFunc(msg)
|
||||
}
|
||||
|
||||
func (sm *MockClient) Send(msg *proto.Message) error {
|
||||
if sm.SendFunc == nil {
|
||||
return nil
|
||||
}
|
||||
return sm.SendFunc(msg)
|
||||
}
|
||||
5
shared/signal/proto/constants.go
Normal file
5
shared/signal/proto/constants.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package proto
|
||||
|
||||
// protocol constants, field names that can be used by both client and server
|
||||
const HeaderId = "x-wiretrustee-peer-id"
|
||||
const HeaderRegistered = "x-wiretrustee-peer-registered"
|
||||
17
shared/signal/proto/generate.sh
Executable file
17
shared/signal/proto/generate.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if ! which realpath > /dev/null 2>&1
|
||||
then
|
||||
echo realpath is not installed
|
||||
echo run: brew install coreutils
|
||||
exit 1
|
||||
fi
|
||||
|
||||
old_pwd=$(pwd)
|
||||
script_path=$(dirname $(realpath "$0"))
|
||||
cd "$script_path"
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||
protoc -I ./ ./signalexchange.proto --go_out=../ --go-grpc_out=../
|
||||
cd "$old_pwd"
|
||||
2
shared/signal/proto/go.sum
Normal file
2
shared/signal/proto/go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0=
|
||||
google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY=
|
||||
624
shared/signal/proto/signalexchange.pb.go
Normal file
624
shared/signal/proto/signalexchange.pb.go
Normal file
@@ -0,0 +1,624 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.26.0
|
||||
// protoc v3.21.12
|
||||
// source: signalexchange.proto
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
_ "google.golang.org/protobuf/types/descriptorpb"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// Message type
|
||||
type Body_Type int32
|
||||
|
||||
const (
|
||||
Body_OFFER Body_Type = 0
|
||||
Body_ANSWER Body_Type = 1
|
||||
Body_CANDIDATE Body_Type = 2
|
||||
Body_MODE Body_Type = 4
|
||||
Body_GO_IDLE Body_Type = 5
|
||||
)
|
||||
|
||||
// Enum value maps for Body_Type.
|
||||
var (
|
||||
Body_Type_name = map[int32]string{
|
||||
0: "OFFER",
|
||||
1: "ANSWER",
|
||||
2: "CANDIDATE",
|
||||
4: "MODE",
|
||||
5: "GO_IDLE",
|
||||
}
|
||||
Body_Type_value = map[string]int32{
|
||||
"OFFER": 0,
|
||||
"ANSWER": 1,
|
||||
"CANDIDATE": 2,
|
||||
"MODE": 4,
|
||||
"GO_IDLE": 5,
|
||||
}
|
||||
)
|
||||
|
||||
func (x Body_Type) Enum() *Body_Type {
|
||||
p := new(Body_Type)
|
||||
*p = x
|
||||
return p
|
||||
}
|
||||
|
||||
func (x Body_Type) String() string {
|
||||
return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x))
|
||||
}
|
||||
|
||||
func (Body_Type) Descriptor() protoreflect.EnumDescriptor {
|
||||
return file_signalexchange_proto_enumTypes[0].Descriptor()
|
||||
}
|
||||
|
||||
func (Body_Type) Type() protoreflect.EnumType {
|
||||
return &file_signalexchange_proto_enumTypes[0]
|
||||
}
|
||||
|
||||
func (x Body_Type) Number() protoreflect.EnumNumber {
|
||||
return protoreflect.EnumNumber(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Body_Type.Descriptor instead.
|
||||
func (Body_Type) EnumDescriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{2, 0}
|
||||
}
|
||||
|
||||
// Used for sending through signal.
|
||||
// The body of this message is the Body message encrypted with the Wireguard private key and the remote Peer key
|
||||
type EncryptedMessage struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Wireguard public key
|
||||
Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"`
|
||||
// Wireguard public key of the remote peer to connect to
|
||||
RemoteKey string `protobuf:"bytes,3,opt,name=remoteKey,proto3" json:"remoteKey,omitempty"`
|
||||
// encrypted message Body
|
||||
Body []byte `protobuf:"bytes,4,opt,name=body,proto3" json:"body,omitempty"`
|
||||
}
|
||||
|
||||
func (x *EncryptedMessage) Reset() {
|
||||
*x = EncryptedMessage{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_signalexchange_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *EncryptedMessage) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*EncryptedMessage) ProtoMessage() {}
|
||||
|
||||
func (x *EncryptedMessage) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_signalexchange_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use EncryptedMessage.ProtoReflect.Descriptor instead.
|
||||
func (*EncryptedMessage) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *EncryptedMessage) GetKey() string {
|
||||
if x != nil {
|
||||
return x.Key
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *EncryptedMessage) GetRemoteKey() string {
|
||||
if x != nil {
|
||||
return x.RemoteKey
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *EncryptedMessage) GetBody() []byte {
|
||||
if x != nil {
|
||||
return x.Body
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// A decrypted representation of the EncryptedMessage. Used locally before/after encryption
|
||||
type Message struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// WireGuard public key
|
||||
Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"`
|
||||
// WireGuard public key of the remote peer to connect to
|
||||
RemoteKey string `protobuf:"bytes,3,opt,name=remoteKey,proto3" json:"remoteKey,omitempty"`
|
||||
Body *Body `protobuf:"bytes,4,opt,name=body,proto3" json:"body,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Message) Reset() {
|
||||
*x = Message{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_signalexchange_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Message) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Message) ProtoMessage() {}
|
||||
|
||||
func (x *Message) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_signalexchange_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Message.ProtoReflect.Descriptor instead.
|
||||
func (*Message) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *Message) GetKey() string {
|
||||
if x != nil {
|
||||
return x.Key
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Message) GetRemoteKey() string {
|
||||
if x != nil {
|
||||
return x.RemoteKey
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Message) GetBody() *Body {
|
||||
if x != nil {
|
||||
return x.Body
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Actual body of the message that can contain credentials (type OFFER/ANSWER) or connection Candidate
|
||||
// This part will be encrypted
|
||||
type Body struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Type Body_Type `protobuf:"varint,1,opt,name=type,proto3,enum=signalexchange.Body_Type" json:"type,omitempty"`
|
||||
Payload string `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
|
||||
// wgListenPort is an actual WireGuard listen port
|
||||
WgListenPort uint32 `protobuf:"varint,3,opt,name=wgListenPort,proto3" json:"wgListenPort,omitempty"`
|
||||
NetBirdVersion string `protobuf:"bytes,4,opt,name=netBirdVersion,proto3" json:"netBirdVersion,omitempty"`
|
||||
Mode *Mode `protobuf:"bytes,5,opt,name=mode,proto3" json:"mode,omitempty"`
|
||||
// featuresSupported list of supported features by the client of this protocol
|
||||
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
|
||||
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
|
||||
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
|
||||
// relayServerAddress is url of the relay server
|
||||
RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Body) Reset() {
|
||||
*x = Body{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_signalexchange_proto_msgTypes[2]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Body) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Body) ProtoMessage() {}
|
||||
|
||||
func (x *Body) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_signalexchange_proto_msgTypes[2]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Body.ProtoReflect.Descriptor instead.
|
||||
func (*Body) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{2}
|
||||
}
|
||||
|
||||
func (x *Body) GetType() Body_Type {
|
||||
if x != nil {
|
||||
return x.Type
|
||||
}
|
||||
return Body_OFFER
|
||||
}
|
||||
|
||||
func (x *Body) GetPayload() string {
|
||||
if x != nil {
|
||||
return x.Payload
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Body) GetWgListenPort() uint32 {
|
||||
if x != nil {
|
||||
return x.WgListenPort
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *Body) GetNetBirdVersion() string {
|
||||
if x != nil {
|
||||
return x.NetBirdVersion
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Body) GetMode() *Mode {
|
||||
if x != nil {
|
||||
return x.Mode
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Body) GetFeaturesSupported() []uint32 {
|
||||
if x != nil {
|
||||
return x.FeaturesSupported
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Body) GetRosenpassConfig() *RosenpassConfig {
|
||||
if x != nil {
|
||||
return x.RosenpassConfig
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Body) GetRelayServerAddress() string {
|
||||
if x != nil {
|
||||
return x.RelayServerAddress
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Mode indicates a connection mode
|
||||
type Mode struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
Direct *bool `protobuf:"varint,1,opt,name=direct,proto3,oneof" json:"direct,omitempty"`
|
||||
}
|
||||
|
||||
func (x *Mode) Reset() {
|
||||
*x = Mode{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_signalexchange_proto_msgTypes[3]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Mode) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Mode) ProtoMessage() {}
|
||||
|
||||
func (x *Mode) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_signalexchange_proto_msgTypes[3]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Mode.ProtoReflect.Descriptor instead.
|
||||
func (*Mode) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{3}
|
||||
}
|
||||
|
||||
func (x *Mode) GetDirect() bool {
|
||||
if x != nil && x.Direct != nil {
|
||||
return *x.Direct
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type RosenpassConfig struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
RosenpassPubKey []byte `protobuf:"bytes,1,opt,name=rosenpassPubKey,proto3" json:"rosenpassPubKey,omitempty"`
|
||||
// rosenpassServerAddr is an IP:port of the rosenpass service
|
||||
RosenpassServerAddr string `protobuf:"bytes,2,opt,name=rosenpassServerAddr,proto3" json:"rosenpassServerAddr,omitempty"`
|
||||
}
|
||||
|
||||
func (x *RosenpassConfig) Reset() {
|
||||
*x = RosenpassConfig{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_signalexchange_proto_msgTypes[4]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *RosenpassConfig) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*RosenpassConfig) ProtoMessage() {}
|
||||
|
||||
func (x *RosenpassConfig) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_signalexchange_proto_msgTypes[4]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use RosenpassConfig.ProtoReflect.Descriptor instead.
|
||||
func (*RosenpassConfig) Descriptor() ([]byte, []int) {
|
||||
return file_signalexchange_proto_rawDescGZIP(), []int{4}
|
||||
}
|
||||
|
||||
func (x *RosenpassConfig) GetRosenpassPubKey() []byte {
|
||||
if x != nil {
|
||||
return x.RosenpassPubKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *RosenpassConfig) GetRosenpassServerAddr() string {
|
||||
if x != nil {
|
||||
return x.RosenpassServerAddr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_signalexchange_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_signalexchange_proto_rawDesc = []byte{
|
||||
0x0a, 0x14, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65,
|
||||
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
|
||||
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74,
|
||||
0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x56, 0x0a, 0x10, 0x45, 0x6e, 0x63, 0x72,
|
||||
0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03,
|
||||
0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1c,
|
||||
0x0a, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28,
|
||||
0x09, 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04,
|
||||
0x62, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x62, 0x6f, 0x64, 0x79,
|
||||
0x22, 0x63, 0x0a, 0x07, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6b,
|
||||
0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x1c, 0x0a,
|
||||
0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
|
||||
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
|
||||
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
|
||||
0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xb3, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
|
||||
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
|
||||
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
|
||||
0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
|
||||
0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x12, 0x22, 0x0a, 0x0c, 0x77, 0x67, 0x4c, 0x69, 0x73,
|
||||
0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0c, 0x77,
|
||||
0x67, 0x4c, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x26, 0x0a, 0x0e, 0x6e,
|
||||
0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20,
|
||||
0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x74, 0x42, 0x69, 0x72, 0x64, 0x56, 0x65, 0x72, 0x73,
|
||||
0x69, 0x6f, 0x6e, 0x12, 0x28, 0x0a, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28,
|
||||
0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e,
|
||||
0x67, 0x65, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x2c, 0x0a,
|
||||
0x11, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72, 0x65, 0x73, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74,
|
||||
0x65, 0x64, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0d, 0x52, 0x11, 0x66, 0x65, 0x61, 0x74, 0x75, 0x72,
|
||||
0x65, 0x73, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x65, 0x64, 0x12, 0x49, 0x0a, 0x0f, 0x72,
|
||||
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x07,
|
||||
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
|
||||
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
|
||||
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
|
||||
0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
|
||||
0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
|
||||
0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
|
||||
0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x43, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
|
||||
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
|
||||
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
|
||||
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x12, 0x0b,
|
||||
0x0a, 0x07, 0x47, 0x4f, 0x5f, 0x49, 0x44, 0x4c, 0x45, 0x10, 0x05, 0x22, 0x2e, 0x0a, 0x04, 0x4d,
|
||||
0x6f, 0x64, 0x65, 0x12, 0x1b, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x01, 0x20,
|
||||
0x01, 0x28, 0x08, 0x48, 0x00, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x88, 0x01, 0x01,
|
||||
0x42, 0x09, 0x0a, 0x07, 0x5f, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x22, 0x6d, 0x0a, 0x0f, 0x52,
|
||||
0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x28,
|
||||
0x0a, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65,
|
||||
0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61,
|
||||
0x73, 0x73, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65,
|
||||
0x6e, 0x70, 0x61, 0x73, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x18,
|
||||
0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
|
||||
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x32, 0xb9, 0x01, 0x0a, 0x0e, 0x53,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x45, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x4c, 0x0a,
|
||||
0x04, 0x53, 0x65, 0x6e, 0x64, 0x12, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78,
|
||||
0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64,
|
||||
0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c,
|
||||
0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74,
|
||||
0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x59, 0x0a, 0x0d, 0x43,
|
||||
0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x20, 0x2e, 0x73,
|
||||
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x45, 0x6e,
|
||||
0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x20,
|
||||
0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e,
|
||||
0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
|
||||
0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
|
||||
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_signalexchange_proto_rawDescOnce sync.Once
|
||||
file_signalexchange_proto_rawDescData = file_signalexchange_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_signalexchange_proto_rawDescGZIP() []byte {
|
||||
file_signalexchange_proto_rawDescOnce.Do(func() {
|
||||
file_signalexchange_proto_rawDescData = protoimpl.X.CompressGZIP(file_signalexchange_proto_rawDescData)
|
||||
})
|
||||
return file_signalexchange_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_signalexchange_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
|
||||
var file_signalexchange_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
|
||||
var file_signalexchange_proto_goTypes = []interface{}{
|
||||
(Body_Type)(0), // 0: signalexchange.Body.Type
|
||||
(*EncryptedMessage)(nil), // 1: signalexchange.EncryptedMessage
|
||||
(*Message)(nil), // 2: signalexchange.Message
|
||||
(*Body)(nil), // 3: signalexchange.Body
|
||||
(*Mode)(nil), // 4: signalexchange.Mode
|
||||
(*RosenpassConfig)(nil), // 5: signalexchange.RosenpassConfig
|
||||
}
|
||||
var file_signalexchange_proto_depIdxs = []int32{
|
||||
3, // 0: signalexchange.Message.body:type_name -> signalexchange.Body
|
||||
0, // 1: signalexchange.Body.type:type_name -> signalexchange.Body.Type
|
||||
4, // 2: signalexchange.Body.mode:type_name -> signalexchange.Mode
|
||||
5, // 3: signalexchange.Body.rosenpassConfig:type_name -> signalexchange.RosenpassConfig
|
||||
1, // 4: signalexchange.SignalExchange.Send:input_type -> signalexchange.EncryptedMessage
|
||||
1, // 5: signalexchange.SignalExchange.ConnectStream:input_type -> signalexchange.EncryptedMessage
|
||||
1, // 6: signalexchange.SignalExchange.Send:output_type -> signalexchange.EncryptedMessage
|
||||
1, // 7: signalexchange.SignalExchange.ConnectStream:output_type -> signalexchange.EncryptedMessage
|
||||
6, // [6:8] is the sub-list for method output_type
|
||||
4, // [4:6] is the sub-list for method input_type
|
||||
4, // [4:4] is the sub-list for extension type_name
|
||||
4, // [4:4] is the sub-list for extension extendee
|
||||
0, // [0:4] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_signalexchange_proto_init() }
|
||||
func file_signalexchange_proto_init() {
|
||||
if File_signalexchange_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_signalexchange_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*EncryptedMessage); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Message); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Body); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Mode); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*RosenpassConfig); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_signalexchange_proto_msgTypes[3].OneofWrappers = []interface{}{}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_signalexchange_proto_rawDesc,
|
||||
NumEnums: 1,
|
||||
NumMessages: 5,
|
||||
NumExtensions: 0,
|
||||
NumServices: 1,
|
||||
},
|
||||
GoTypes: file_signalexchange_proto_goTypes,
|
||||
DependencyIndexes: file_signalexchange_proto_depIdxs,
|
||||
EnumInfos: file_signalexchange_proto_enumTypes,
|
||||
MessageInfos: file_signalexchange_proto_msgTypes,
|
||||
}.Build()
|
||||
File_signalexchange_proto = out.File
|
||||
file_signalexchange_proto_rawDesc = nil
|
||||
file_signalexchange_proto_goTypes = nil
|
||||
file_signalexchange_proto_depIdxs = nil
|
||||
}
|
||||
78
shared/signal/proto/signalexchange.proto
Normal file
78
shared/signal/proto/signalexchange.proto
Normal file
@@ -0,0 +1,78 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/descriptor.proto";
|
||||
|
||||
option go_package = "/proto";
|
||||
|
||||
package signalexchange;
|
||||
|
||||
service SignalExchange {
|
||||
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
|
||||
rpc Send(EncryptedMessage) returns (EncryptedMessage) {}
|
||||
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
|
||||
rpc ConnectStream(stream EncryptedMessage) returns (stream EncryptedMessage) {}
|
||||
}
|
||||
|
||||
// Used for sending through signal.
|
||||
// The body of this message is the Body message encrypted with the Wireguard private key and the remote Peer key
|
||||
message EncryptedMessage {
|
||||
|
||||
// Wireguard public key
|
||||
string key = 2;
|
||||
|
||||
// Wireguard public key of the remote peer to connect to
|
||||
string remoteKey = 3;
|
||||
|
||||
// encrypted message Body
|
||||
bytes body = 4;
|
||||
}
|
||||
|
||||
// A decrypted representation of the EncryptedMessage. Used locally before/after encryption
|
||||
message Message {
|
||||
// WireGuard public key
|
||||
string key = 2;
|
||||
|
||||
// WireGuard public key of the remote peer to connect to
|
||||
string remoteKey = 3;
|
||||
|
||||
Body body = 4;
|
||||
}
|
||||
|
||||
// Actual body of the message that can contain credentials (type OFFER/ANSWER) or connection Candidate
|
||||
// This part will be encrypted
|
||||
message Body {
|
||||
// Message type
|
||||
enum Type {
|
||||
OFFER = 0;
|
||||
ANSWER = 1;
|
||||
CANDIDATE = 2;
|
||||
MODE = 4;
|
||||
GO_IDLE = 5;
|
||||
}
|
||||
Type type = 1;
|
||||
string payload = 2;
|
||||
// wgListenPort is an actual WireGuard listen port
|
||||
uint32 wgListenPort = 3;
|
||||
string netBirdVersion = 4;
|
||||
Mode mode = 5;
|
||||
|
||||
// featuresSupported list of supported features by the client of this protocol
|
||||
repeated uint32 featuresSupported = 6;
|
||||
|
||||
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
|
||||
RosenpassConfig rosenpassConfig = 7;
|
||||
|
||||
// relayServerAddress is url of the relay server
|
||||
string relayServerAddress = 8;
|
||||
}
|
||||
|
||||
// Mode indicates a connection mode
|
||||
message Mode {
|
||||
optional bool direct = 1;
|
||||
}
|
||||
|
||||
message RosenpassConfig {
|
||||
bytes rosenpassPubKey = 1;
|
||||
// rosenpassServerAddr is an IP:port of the rosenpass service
|
||||
string rosenpassServerAddr = 2;
|
||||
}
|
||||
174
shared/signal/proto/signalexchange_grpc.pb.go
Normal file
174
shared/signal/proto/signalexchange_grpc.pb.go
Normal file
@@ -0,0 +1,174 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
|
||||
package proto
|
||||
|
||||
import (
|
||||
context "context"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7
|
||||
|
||||
// SignalExchangeClient is the client API for SignalExchange service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type SignalExchangeClient interface {
|
||||
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
|
||||
Send(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
|
||||
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
|
||||
ConnectStream(ctx context.Context, opts ...grpc.CallOption) (SignalExchange_ConnectStreamClient, error)
|
||||
}
|
||||
|
||||
type signalExchangeClient struct {
|
||||
cc grpc.ClientConnInterface
|
||||
}
|
||||
|
||||
func NewSignalExchangeClient(cc grpc.ClientConnInterface) SignalExchangeClient {
|
||||
return &signalExchangeClient{cc}
|
||||
}
|
||||
|
||||
func (c *signalExchangeClient) Send(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
|
||||
out := new(EncryptedMessage)
|
||||
err := c.cc.Invoke(ctx, "/signalexchange.SignalExchange/Send", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *signalExchangeClient) ConnectStream(ctx context.Context, opts ...grpc.CallOption) (SignalExchange_ConnectStreamClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, &SignalExchange_ServiceDesc.Streams[0], "/signalexchange.SignalExchange/ConnectStream", opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &signalExchangeConnectStreamClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type SignalExchange_ConnectStreamClient interface {
|
||||
Send(*EncryptedMessage) error
|
||||
Recv() (*EncryptedMessage, error)
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
type signalExchangeConnectStreamClient struct {
|
||||
grpc.ClientStream
|
||||
}
|
||||
|
||||
func (x *signalExchangeConnectStreamClient) Send(m *EncryptedMessage) error {
|
||||
return x.ClientStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *signalExchangeConnectStreamClient) Recv() (*EncryptedMessage, error) {
|
||||
m := new(EncryptedMessage)
|
||||
if err := x.ClientStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// SignalExchangeServer is the server API for SignalExchange service.
|
||||
// All implementations must embed UnimplementedSignalExchangeServer
|
||||
// for forward compatibility
|
||||
type SignalExchangeServer interface {
|
||||
// Synchronously connect to the Signal Exchange service offering connection candidates and waiting for connection candidates from the other party (remote peer)
|
||||
Send(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
|
||||
// Connect to the Signal Exchange service offering connection candidates and maintain a channel for receiving candidates from the other party (remote peer)
|
||||
ConnectStream(SignalExchange_ConnectStreamServer) error
|
||||
mustEmbedUnimplementedSignalExchangeServer()
|
||||
}
|
||||
|
||||
// UnimplementedSignalExchangeServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedSignalExchangeServer struct {
|
||||
}
|
||||
|
||||
func (UnimplementedSignalExchangeServer) Send(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method Send not implemented")
|
||||
}
|
||||
func (UnimplementedSignalExchangeServer) ConnectStream(SignalExchange_ConnectStreamServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method ConnectStream not implemented")
|
||||
}
|
||||
func (UnimplementedSignalExchangeServer) mustEmbedUnimplementedSignalExchangeServer() {}
|
||||
|
||||
// UnsafeSignalExchangeServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to SignalExchangeServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeSignalExchangeServer interface {
|
||||
mustEmbedUnimplementedSignalExchangeServer()
|
||||
}
|
||||
|
||||
func RegisterSignalExchangeServer(s grpc.ServiceRegistrar, srv SignalExchangeServer) {
|
||||
s.RegisterService(&SignalExchange_ServiceDesc, srv)
|
||||
}
|
||||
|
||||
func _SignalExchange_Send_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(SignalExchangeServer).Send(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/signalexchange.SignalExchange/Send",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(SignalExchangeServer).Send(ctx, req.(*EncryptedMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _SignalExchange_ConnectStream_Handler(srv interface{}, stream grpc.ServerStream) error {
|
||||
return srv.(SignalExchangeServer).ConnectStream(&signalExchangeConnectStreamServer{stream})
|
||||
}
|
||||
|
||||
type SignalExchange_ConnectStreamServer interface {
|
||||
Send(*EncryptedMessage) error
|
||||
Recv() (*EncryptedMessage, error)
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
type signalExchangeConnectStreamServer struct {
|
||||
grpc.ServerStream
|
||||
}
|
||||
|
||||
func (x *signalExchangeConnectStreamServer) Send(m *EncryptedMessage) error {
|
||||
return x.ServerStream.SendMsg(m)
|
||||
}
|
||||
|
||||
func (x *signalExchangeConnectStreamServer) Recv() (*EncryptedMessage, error) {
|
||||
m := new(EncryptedMessage)
|
||||
if err := x.ServerStream.RecvMsg(m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// SignalExchange_ServiceDesc is the grpc.ServiceDesc for SignalExchange service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var SignalExchange_ServiceDesc = grpc.ServiceDesc{
|
||||
ServiceName: "signalexchange.SignalExchange",
|
||||
HandlerType: (*SignalExchangeServer)(nil),
|
||||
Methods: []grpc.MethodDesc{
|
||||
{
|
||||
MethodName: "Send",
|
||||
Handler: _SignalExchange_Send_Handler,
|
||||
},
|
||||
},
|
||||
Streams: []grpc.StreamDesc{
|
||||
{
|
||||
StreamName: "ConnectStream",
|
||||
Handler: _SignalExchange_ConnectStream_Handler,
|
||||
ServerStreams: true,
|
||||
ClientStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "signalexchange.proto",
|
||||
}
|
||||
Reference in New Issue
Block a user