Merge main branch into feature/client-metrics

This commit is contained in:
Zoltán Papp
2026-01-21 11:22:01 +01:00
162 changed files with 16301 additions and 2734 deletions

View File

@@ -14,6 +14,7 @@ import (
type Client interface {
io.Closer
Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)

View File

@@ -18,12 +18,13 @@ import (
"google.golang.org/grpc/status"
"github.com/netbirdio/management-integrations/integrations"
ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/encryption"
@@ -92,6 +93,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersManger := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManger)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore)
@@ -117,8 +119,8 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManger), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManger), config)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -129,7 +131,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -12,6 +12,7 @@ import (
gstatus "google.golang.org/grpc/status"
"github.com/cenkalti/backoff/v4"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
@@ -118,8 +119,26 @@ func (c *GrpcClient) ready() bool {
// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
backOff := defaultBackoff(ctx)
return c.withMgmtStream(ctx, func(ctx context.Context, serverPubKey wgtypes.Key) error {
return c.handleSyncStream(ctx, serverPubKey, sysInfo, msgHandler)
})
}
// Job wraps the real client's Job endpoint call and takes care of retries and encryption/decryption of messages
// Blocking request. The result will be sent via msgHandler callback function
func (c *GrpcClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error {
return c.withMgmtStream(ctx, func(ctx context.Context, serverPubKey wgtypes.Key) error {
return c.handleJobStream(ctx, serverPubKey, msgHandler)
})
}
// withMgmtStream runs a streaming operation against the ManagementService
// It takes care of retries, connection readiness, and fetching server public key.
func (c *GrpcClient) withMgmtStream(
ctx context.Context,
handler func(ctx context.Context, serverPubKey wgtypes.Key) error,
) error {
backOff := defaultBackoff(ctx)
operation := func() error {
log.Debugf("management connection state %v", c.conn.GetState())
connState := c.conn.GetState()
@@ -137,7 +156,7 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err
}
return c.handleStream(ctx, *serverPubKey, sysInfo, msgHandler, backOff)
return handler(ctx, *serverPubKey)
}
err := backoff.Retry(operation, backOff)
@@ -148,12 +167,153 @@ func (c *GrpcClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return err
}
func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info,
msgHandler func(msg *proto.SyncResponse) error, backOff backoff.BackOff) error {
func (c *GrpcClient) handleJobStream(
ctx context.Context,
serverPubKey wgtypes.Key,
msgHandler func(msg *proto.JobRequest) *proto.JobResponse,
) error {
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, serverPubKey, sysInfo)
stream, err := c.realClient.Job(ctx)
if err != nil {
log.Errorf("failed to open job stream: %v", err)
return err
}
// Handshake with the server
if err := c.sendHandshake(ctx, stream, serverPubKey); err != nil {
return err
}
log.Debug("job stream handshake sent successfully")
// Main loop: receive, process, respond
for {
jobReq, err := c.receiveJobRequest(ctx, stream, serverPubKey)
if err != nil {
if s, ok := gstatus.FromError(err); ok {
switch s.Code() {
case codes.PermissionDenied:
c.notifyDisconnected(err)
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return err
case codes.Unimplemented:
log.Warn("Job feature is not supported by the current management server version. " +
"Please update the management service to use this feature.")
return nil
default:
c.notifyDisconnected(err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
} else {
// non-gRPC error
c.notifyDisconnected(err)
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
}
if jobReq == nil || len(jobReq.ID) == 0 {
log.Debug("received unknown or empty job request, skipping")
continue
}
log.Infof("received a new job from the management server (ID: %s)", jobReq.ID)
jobResp := c.processJobRequest(ctx, jobReq, msgHandler)
if err := c.sendJobResponse(ctx, stream, serverPubKey, jobResp); err != nil {
return err
}
}
}
// sendHandshake sends the initial handshake message
func (c *GrpcClient) sendHandshake(ctx context.Context, stream proto.ManagementService_JobClient, serverPubKey wgtypes.Key) error {
handshakeReq := &proto.JobRequest{
ID: []byte(uuid.New().String()),
}
encHello, err := encryption.EncryptMessage(serverPubKey, c.key, handshakeReq)
if err != nil {
log.Errorf("failed to encrypt handshake message: %v", err)
return err
}
return stream.Send(&proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encHello,
})
}
// receiveJobRequest waits for and decrypts a job request
func (c *GrpcClient) receiveJobRequest(
ctx context.Context,
stream proto.ManagementService_JobClient,
serverPubKey wgtypes.Key,
) (*proto.JobRequest, error) {
encryptedMsg, err := stream.Recv()
if err != nil {
return nil, err
}
jobReq := &proto.JobRequest{}
if err := encryption.DecryptMessage(serverPubKey, c.key, encryptedMsg.Body, jobReq); err != nil {
log.Warnf("failed to decrypt job request: %v", err)
return nil, err
}
return jobReq, nil
}
// processJobRequest executes the handler and ensures a valid response
func (c *GrpcClient) processJobRequest(
ctx context.Context,
jobReq *proto.JobRequest,
msgHandler func(msg *proto.JobRequest) *proto.JobResponse,
) *proto.JobResponse {
jobResp := msgHandler(jobReq)
if jobResp == nil {
jobResp = &proto.JobResponse{
ID: jobReq.ID,
Status: proto.JobStatus_failed,
Reason: []byte("handler returned nil response"),
}
log.Warnf("job handler returned nil for job %s", string(jobReq.ID))
}
return jobResp
}
// sendJobResponse encrypts and sends a job response
func (c *GrpcClient) sendJobResponse(
ctx context.Context,
stream proto.ManagementService_JobClient,
serverPubKey wgtypes.Key,
resp *proto.JobResponse,
) error {
encResp, err := encryption.EncryptMessage(serverPubKey, c.key, resp)
if err != nil {
log.Errorf("failed to encrypt job response for job %s: %v", string(resp.ID), err)
return err
}
if err := stream.Send(&proto.EncryptedMessage{
WgPubKey: c.key.PublicKey().String(),
Body: encResp,
}); err != nil {
log.Errorf("failed to send job response for job %s: %v", string(resp.ID), err)
return err
}
log.Infof("job response sent for job %s (status: %s)", string(resp.ID), resp.Status.String())
return nil
}
func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error {
ctx, cancelStream := context.WithCancel(ctx)
defer cancelStream()
stream, err := c.connectToSyncStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -166,20 +326,22 @@ func (c *GrpcClient) handleStream(ctx context.Context, serverPubKey wgtypes.Key,
c.notifyConnected()
// blocking until error
err = c.receiveEvents(stream, serverPubKey, msgHandler)
// we need this reset because after a successful connection and a consequent error, backoff lib doesn't
// reset times and next try will start with a long delay
backOff.Reset()
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
s, _ := gstatus.FromError(err)
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
if s, ok := gstatus.FromError(err); ok {
switch s.Code() {
case codes.PermissionDenied:
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
case codes.Canceled:
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
default:
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
} else {
// non-gRPC error
log.Warnf("disconnected from the Management service but will retry silently. Reason: %v", err)
return err
}
@@ -198,7 +360,7 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey, sysInfo)
stream, err := c.connectToSyncStream(ctx, *serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
return nil, err
@@ -231,7 +393,7 @@ func (c *GrpcClient) GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, err
return decryptedResp.GetNetworkMap(), nil
}
func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
func (c *GrpcClient) connectToSyncStream(ctx context.Context, serverPubKey wgtypes.Key, sysInfo *system.Info) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{Meta: infoToMetaData(sysInfo)}
myPrivateKey := c.key
@@ -250,7 +412,7 @@ func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.K
return sync, nil
}
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
func (c *GrpcClient) receiveUpdatesEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {
for {
update, err := stream.Recv()
if err == io.EOF {

View File

@@ -20,6 +20,7 @@ type MockClient struct {
GetPKCEAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.PKCEAuthorizationFlow, error)
SyncMetaFunc func(sysInfo *system.Info) error
LogoutFunc func() error
JobFunc func(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
}
func (m *MockClient) IsHealthy() bool {
@@ -40,6 +41,13 @@ func (m *MockClient) Sync(ctx context.Context, sysInfo *system.Info, msgHandler
return m.SyncFunc(ctx, sysInfo, msgHandler)
}
func (m *MockClient) Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error {
if m.JobFunc == nil {
return nil
}
return m.JobFunc(ctx, msgHandler)
}
func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
if m.GetServerPublicKeyFunc == nil {
return nil, nil

View File

@@ -59,9 +59,13 @@ type Client struct {
Routes *RoutesAPI
// DNS NetBird DNS APIs
// see more: https://docs.netbird.io/api/resources/routes
// see more: https://docs.netbird.io/api/resources/dns
DNS *DNSAPI
// DNSZones NetBird DNS Zones APIs
// see more: https://docs.netbird.io/api/resources/dns-zones
DNSZones *DNSZonesAPI
// GeoLocation NetBird Geo Location APIs
// see more: https://docs.netbird.io/api/resources/geo-locations
GeoLocation *GeoLocationAPI
@@ -113,6 +117,7 @@ func (c *Client) initialize() {
c.Networks = &NetworksAPI{c}
c.Routes = &RoutesAPI{c}
c.DNS = &DNSAPI{c}
c.DNSZones = &DNSZonesAPI{c}
c.GeoLocation = &GeoLocationAPI{c}
c.Events = &EventsAPI{c}
}

View File

@@ -0,0 +1,170 @@
package rest
import (
"bytes"
"context"
"encoding/json"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// DNSZonesAPI APIs for DNS Zones Management, do not use directly
type DNSZonesAPI struct {
c *Client
}
// ListZones list all DNS zones
// See more: https://docs.netbird.io/api/resources/dns-zones#list-all-dns-zones
func (a *DNSZonesAPI) ListZones(ctx context.Context) ([]api.Zone, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.Zone](resp)
return ret, err
}
// GetZone get DNS zone info
// See more: https://docs.netbird.io/api/resources/dns-zones#retrieve-a-dns-zone
func (a *DNSZonesAPI) GetZone(ctx context.Context, zoneID string) (*api.Zone, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Zone](resp)
return &ret, err
}
// CreateZone create new DNS zone
// See more: https://docs.netbird.io/api/resources/dns-zones#create-a-dns-zone
func (a *DNSZonesAPI) CreateZone(ctx context.Context, request api.PostApiDnsZonesJSONRequestBody) (*api.Zone, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/zones", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Zone](resp)
return &ret, err
}
// UpdateZone update DNS zone info
// See more: https://docs.netbird.io/api/resources/dns-zones#update-a-dns-zone
func (a *DNSZonesAPI) UpdateZone(ctx context.Context, zoneID string, request api.PutApiDnsZonesZoneIdJSONRequestBody) (*api.Zone, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/zones/"+zoneID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.Zone](resp)
return &ret, err
}
// DeleteZone delete DNS zone
// See more: https://docs.netbird.io/api/resources/dns-zones#delete-a-dns-zone
func (a *DNSZonesAPI) DeleteZone(ctx context.Context, zoneID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/zones/"+zoneID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}
// ListRecords list all DNS records in a zone
// See more: https://docs.netbird.io/api/resources/dns-zones#list-all-dns-records
func (a *DNSZonesAPI) ListRecords(ctx context.Context, zoneID string) ([]api.DNSRecord, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID+"/records", nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[[]api.DNSRecord](resp)
return ret, err
}
// GetRecord get DNS record info
// See more: https://docs.netbird.io/api/resources/dns-zones#retrieve-a-dns-record
func (a *DNSZonesAPI) GetRecord(ctx context.Context, zoneID, recordID string) (*api.DNSRecord, error) {
resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID+"/records/"+recordID, nil, nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSRecord](resp)
return &ret, err
}
// CreateRecord create new DNS record in a zone
// See more: https://docs.netbird.io/api/resources/dns-zones#create-a-dns-record
func (a *DNSZonesAPI) CreateRecord(ctx context.Context, zoneID string, request api.PostApiDnsZonesZoneIdRecordsJSONRequestBody) (*api.DNSRecord, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/zones/"+zoneID+"/records", bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSRecord](resp)
return &ret, err
}
// UpdateRecord update DNS record info
// See more: https://docs.netbird.io/api/resources/dns-zones#update-a-dns-record
func (a *DNSZonesAPI) UpdateRecord(ctx context.Context, zoneID, recordID string, request api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody) (*api.DNSRecord, error) {
requestBytes, err := json.Marshal(request)
if err != nil {
return nil, err
}
resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/zones/"+zoneID+"/records/"+recordID, bytes.NewReader(requestBytes), nil)
if err != nil {
return nil, err
}
if resp.Body != nil {
defer resp.Body.Close()
}
ret, err := parseResponse[api.DNSRecord](resp)
return &ret, err
}
// DeleteRecord delete DNS record
// See more: https://docs.netbird.io/api/resources/dns-zones#delete-a-dns-record
func (a *DNSZonesAPI) DeleteRecord(ctx context.Context, zoneID, recordID string) error {
resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/zones/"+zoneID+"/records/"+recordID, nil, nil)
if err != nil {
return err
}
if resp.Body != nil {
defer resp.Body.Close()
}
return nil
}

View File

@@ -0,0 +1,460 @@
//go:build integration
// +build integration
package rest_test
import (
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/client/rest"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
var (
testZone = api.Zone{
Id: "zone123",
Name: "test-zone",
Domain: "example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
}
testDNSRecord = api.DNSRecord{
Id: "record123",
Name: "www",
Content: "192.168.1.1",
Type: api.DNSRecordTypeA,
Ttl: 300,
}
)
func TestDNSZone_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.Zone{testZone})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.ListZones(context.Background())
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testZone, ret[0])
})
}
func TestDNSZone_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.ListZones(context.Background())
assert.Error(t, err)
assert.Equal(t, "No", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSZone_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testZone)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.GetZone(context.Background(), "zone123")
require.NoError(t, err)
assert.Equal(t, testZone, *ret)
})
}
func TestDNSZone_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.GetZone(context.Background(), "zone123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSZone_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiDnsZonesJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "test-zone", req.Name)
assert.Equal(t, "example.com", req.Domain)
retBytes, _ := json.Marshal(testZone)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
enabled := true
ret, err := c.DNSZones.CreateZone(context.Background(), api.PostApiDnsZonesJSONRequestBody{
Name: "test-zone",
Domain: "example.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
})
require.NoError(t, err)
assert.Equal(t, testZone, *ret)
})
}
func TestDNSZone_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid request", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.CreateZone(context.Background(), api.PostApiDnsZonesJSONRequestBody{
Name: "test-zone",
Domain: "example.com",
})
assert.Error(t, err)
assert.Equal(t, "Invalid request", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSZone_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiDnsZonesZoneIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "updated-zone", req.Name)
retBytes, _ := json.Marshal(testZone)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
enabled := true
ret, err := c.DNSZones.UpdateZone(context.Background(), "zone123", api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "updated-zone",
Domain: "example.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{"group1"},
})
require.NoError(t, err)
assert.Equal(t, testZone, *ret)
})
}
func TestDNSZone_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid request", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.UpdateZone(context.Background(), "zone123", api.PutApiDnsZonesZoneIdJSONRequestBody{
Name: "updated-zone",
Domain: "example.com",
})
assert.Error(t, err)
assert.Equal(t, "Invalid request", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSZone_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.DNSZones.DeleteZone(context.Background(), "zone123")
require.NoError(t, err)
})
}
func TestDNSZone_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.DNSZones.DeleteZone(context.Background(), "zone123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestDNSRecord_List_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal([]api.DNSRecord{testDNSRecord})
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.ListRecords(context.Background(), "zone123")
require.NoError(t, err)
assert.Len(t, ret, 1)
assert.Equal(t, testDNSRecord, ret[0])
})
}
func TestDNSRecord_List_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Zone not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.ListRecords(context.Background(), "zone123")
assert.Error(t, err)
assert.Equal(t, "Zone not found", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSRecord_Get_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "GET", r.Method)
retBytes, _ := json.Marshal(testDNSRecord)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.GetRecord(context.Background(), "zone123", "record123")
require.NoError(t, err)
assert.Equal(t, testDNSRecord, *ret)
})
}
func TestDNSRecord_Get_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.GetRecord(context.Background(), "zone123", "record123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
assert.Empty(t, ret)
})
}
func TestDNSRecord_Create_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "POST", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "www", req.Name)
assert.Equal(t, "192.168.1.1", req.Content)
assert.Equal(t, api.DNSRecordTypeA, req.Type)
retBytes, _ := json.Marshal(testDNSRecord)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.CreateRecord(context.Background(), "zone123", api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "www",
Content: "192.168.1.1",
Type: api.DNSRecordTypeA,
Ttl: 300,
})
require.NoError(t, err)
assert.Equal(t, testDNSRecord, *ret)
})
}
func TestDNSRecord_Create_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid record", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.CreateRecord(context.Background(), "zone123", api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{
Name: "www",
Content: "192.168.1.1",
Type: api.DNSRecordTypeA,
Ttl: 300,
})
assert.Error(t, err)
assert.Equal(t, "Invalid record", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSRecord_Update_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "PUT", r.Method)
reqBytes, err := io.ReadAll(r.Body)
require.NoError(t, err)
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
err = json.Unmarshal(reqBytes, &req)
require.NoError(t, err)
assert.Equal(t, "api", req.Name)
assert.Equal(t, "192.168.1.2", req.Content)
retBytes, _ := json.Marshal(testDNSRecord)
_, err = w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.UpdateRecord(context.Background(), "zone123", "record123", api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "api",
Content: "192.168.1.2",
Type: api.DNSRecordTypeA,
Ttl: 300,
})
require.NoError(t, err)
assert.Equal(t, testDNSRecord, *ret)
})
}
func TestDNSRecord_Update_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid record", Code: 400})
w.WriteHeader(400)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
ret, err := c.DNSZones.UpdateRecord(context.Background(), "zone123", "record123", api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{
Name: "api",
Content: "192.168.1.2",
Type: api.DNSRecordTypeA,
Ttl: 300,
})
assert.Error(t, err)
assert.Equal(t, "Invalid record", err.Error())
assert.Nil(t, ret)
})
}
func TestDNSRecord_Delete_200(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "DELETE", r.Method)
w.WriteHeader(200)
})
err := c.DNSZones.DeleteRecord(context.Background(), "zone123", "record123")
require.NoError(t, err)
})
}
func TestDNSRecord_Delete_Err(t *testing.T) {
withMockClient(func(c *rest.Client, mux *http.ServeMux) {
mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) {
retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404})
w.WriteHeader(404)
_, err := w.Write(retBytes)
require.NoError(t, err)
})
err := c.DNSZones.DeleteRecord(context.Background(), "zone123", "record123")
assert.Error(t, err)
assert.Equal(t, "Not found", err.Error())
})
}
func TestDNSZones_Integration(t *testing.T) {
enabled := true
zoneReq := api.ZoneRequest{
Name: "test-zone",
Domain: "test.example.com",
Enabled: &enabled,
EnableSearchDomain: false,
DistributionGroups: []string{"cs1tnh0hhcjnqoiuebeg"},
}
recordReq := api.DNSRecordRequest{
Name: "api.test.example.com",
Content: "192.168.1.100",
Type: api.DNSRecordTypeA,
Ttl: 300,
}
withBlackBoxServer(t, func(c *rest.Client) {
zone, err := c.DNSZones.CreateZone(context.Background(), zoneReq)
require.NoError(t, err)
assert.Equal(t, "test-zone", zone.Name)
assert.Equal(t, "test.example.com", zone.Domain)
zones, err := c.DNSZones.ListZones(context.Background())
require.NoError(t, err)
assert.Equal(t, *zone, zones[0])
getZone, err := c.DNSZones.GetZone(context.Background(), zone.Id)
require.NoError(t, err)
assert.Equal(t, *zone, *getZone)
zoneReq.Name = "updated-zone"
updatedZone, err := c.DNSZones.UpdateZone(context.Background(), zone.Id, zoneReq)
require.NoError(t, err)
assert.Equal(t, "updated-zone", updatedZone.Name)
record, err := c.DNSZones.CreateRecord(context.Background(), zone.Id, recordReq)
require.NoError(t, err)
assert.Equal(t, "api.test.example.com", record.Name)
assert.Equal(t, "192.168.1.100", record.Content)
records, err := c.DNSZones.ListRecords(context.Background(), zone.Id)
require.NoError(t, err)
assert.Equal(t, *record, records[0])
getRecord, err := c.DNSZones.GetRecord(context.Background(), zone.Id, record.Id)
require.NoError(t, err)
assert.Equal(t, *record, *getRecord)
recordReq.Name = "www.test.example.com"
updatedRecord, err := c.DNSZones.UpdateRecord(context.Background(), zone.Id, record.Id, recordReq)
require.NoError(t, err)
assert.Equal(t, "www.test.example.com", updatedRecord.Name)
err = c.DNSZones.DeleteRecord(context.Background(), zone.Id, record.Id)
require.NoError(t, err)
records, err = c.DNSZones.ListRecords(context.Background(), zone.Id)
require.NoError(t, err)
assert.Len(t, records, 0)
err = c.DNSZones.DeleteZone(context.Background(), zone.Id)
require.NoError(t, err)
zones, err = c.DNSZones.ListZones(context.Background())
require.NoError(t, err)
assert.Len(t, zones, 0)
})
}