[management] Add custom dns zones (#4849)

This commit is contained in:
Bethuel Mmbaga
2026-01-16 10:12:05 +01:00
committed by GitHub
parent 291e640b28
commit 067c77e49e
36 changed files with 4837 additions and 63 deletions

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -197,6 +198,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range account.Peers {
if !c.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
@@ -223,9 +230,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -318,7 +325,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
customZone := account.GetPeersCustomZone(ctx, dnsDomain)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
@@ -335,12 +342,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return err
}
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -434,7 +447,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
@@ -445,11 +465,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics)
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -472,7 +492,8 @@ func (c *Controller) getPeerNetworkMapExp(
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
@@ -483,7 +504,7 @@ func (c *Controller) getPeerNetworkMapExp(
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
@@ -798,7 +819,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
if err != nil {
return nil, err
}
customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
@@ -809,11 +838,11 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -3,6 +3,7 @@ package controller
import (
"context"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -14,6 +15,7 @@ type Repository interface {
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error)
}
type repository struct {
@@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}
func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) {
return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}

View File

@@ -0,0 +1,13 @@
package zones
import (
"context"
)
type Manager interface {
GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error)
GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error)
CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error)
DeleteZone(ctx context.Context, accountID, userID, zoneID string) error
}

View File

@@ -0,0 +1,161 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager zones.Manager
}
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiZones := make([]*api.Zone, 0, len(allZones))
for _, zone := range allZones {
apiZones = append(apiZones, zone.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiZones)
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
zone.ID = zoneID
if err = zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,229 @@
package manager
import (
"context"
"fmt"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
dnsDomain string
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
dnsDomain: dnsDomain,
}
}
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID)
}
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil {
return nil, err
}
zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing zone: %w", err)
}
}
if existingZone != nil {
return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain)
}
for _, groupID := range zone.DistributionGroups {
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
if err = transaction.CreateZone(ctx, zone); err != nil {
return fmt.Errorf("failed to create zone: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta())
return zone, nil
}
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID)
if err != nil {
return nil, fmt.Errorf("failed to get zone: %w", err)
}
if zone.Domain != updatedZone.Domain {
return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated")
}
zone.Name = updatedZone.Name
zone.Enabled = updatedZone.Enabled
zone.EnableSearchDomain = updatedZone.EnableSearchDomain
zone.DistributionGroups = updatedZone.DistributionGroups
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range zone.DistributionGroups {
_, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
return status.Errorf(status.InvalidArgument, "%s", err.Error())
}
}
if err = transaction.UpdateZone(ctx, zone); err != nil {
return fmt.Errorf("failed to update zone: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return zone, nil
}
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get records: %w", err)
}
err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to delete zone dns records: %w", err)
}
err = transaction.DeleteZone(ctx, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to delete zone: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
for _, record := range records {
eventsToStore = append(eventsToStore, func() {
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta)
})
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta())
})
return nil
})
if err != nil {
return err
}
for _, event := range eventsToStore {
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error {
if m.dnsDomain != "" && m.dnsDomain == domain {
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
}
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if settings.DNSDomain != "" && settings.DNSDomain == domain {
return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain)
}
return nil
}

View File

@@ -0,0 +1,553 @@
package manager
import (
"context"
"fmt"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testZoneID = "test-zone-id"
testGroupID = "test-group-id"
testDNSDomain = "netbird.selfhosted"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
Name: "Test Group",
},
},
})
require.NoError(t, err)
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
dnsDomain: testDNSDomain,
}
return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllZones(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone1)
require.NoError(t, err)
zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID})
err = testStore.CreateZone(ctx, zone2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, zone1.ID, result[0].ID)
assert.Equal(t, zone2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, zone.ID, result.ID)
assert.Equal(t, zone.Name, result.Name)
assert.Equal(t, zone.Domain, result.Domain)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneCreated, activityID)
}
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ID)
assert.Equal(t, testAccountID, result.AccountID)
assert.Equal(t, inputZone.Name, result.Name)
assert.Equal(t, inputZone.Domain, result.Domain)
assert.Equal(t, inputZone.Enabled, result.Enabled)
assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain)
assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("invalid group", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "new.example.com",
DistributionGroups: []string{"invalid-group"},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("duplicate domain", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
inputZone := &zones.Zone{
Name: "New Zone",
Domain: "duplicate.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.AlreadyExists, s.Type())
})
t.Run("peer DNS domain conflict", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
account, err := testStore.GetAccount(ctx, testAccountID)
require.NoError(t, err)
account.Settings.DNSDomain = "peers.example.com"
err = testStore.SaveAccount(ctx, account)
require.NoError(t, err)
inputZone := &zones.Zone{
Name: "Test Zone",
Domain: "peers.example.com",
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("default DNS domain conflict", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputZone := &zones.Zone{
Name: "Test Zone",
Domain: testDNSDomain,
Enabled: true,
EnableSearchDomain: false,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain))
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
}
func TestManagerImpl_UpdateZone(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
updatedZone := &zones.Zone{
ID: existingZone.ID,
Name: "Updated Name",
Domain: "example.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, existingZone.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneUpdated, activityID)
}
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, updatedZone.Name, result.Name)
assert.Equal(t, updatedZone.Enabled, result.Enabled)
assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
})
t.Run("domain change not allowed", func(t *testing.T) {
manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, existingZone)
require.NoError(t, err)
updatedZone := &zones.Zone{
ID: existingZone.ID,
Name: "Test Zone",
Domain: "different.com",
Enabled: true,
EnableSearchDomain: true,
DistributionGroups: []string{testGroupID},
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "zone domain cannot be updated")
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.InvalidArgument, s.Type())
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: testZoneID,
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedZone := &zones.Zone{
ID: "non-existent-zone",
Name: "Updated Name",
Domain: "example.com",
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_DeleteZone(t *testing.T) {
ctx := context.Background()
t.Run("success with records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCallCount++
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
}
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 3, storeEventCallCount)
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.Error(t, err)
zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.NoError(t, err)
assert.Empty(t, zoneRecords)
})
t.Run("success without records", func(t *testing.T) {
manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err := testStore.CreateZone(ctx, zone)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, zone.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSZoneDeleted, activityID)
}
err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
_, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID)
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("zone not found", func(t *testing.T) {
manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err)
})
}

View File

@@ -0,0 +1,13 @@
package records
import (
"context"
)
type Manager interface {
GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error)
GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error)
CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error)
DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error
}

View File

@@ -0,0 +1,191 @@
package manager
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager records.Manager
}
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiRecords := make([]*api.DNSRecord, 0, len(allRecords))
for _, record := range allRecords {
apiRecords = append(apiRecords, record.ToAPIResponse())
}
util.WriteJSONObject(r.Context(), w, apiRecords)
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
record := new(records.Record)
record.FromAPIRequest(&req)
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
record := new(records.Record)
record.FromAPIRequest(&req)
record.ID = recordID
if err = record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
recordID := mux.Vars(r)["recordId"]
if recordID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w)
return
}
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}

View File

@@ -0,0 +1,236 @@
package manager
import (
"context"
"fmt"
"strings"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
)
type managerImpl struct {
store store.Store
accountManager account.Manager
permissionsManager permissions.Manager
}
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager {
return &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
}
}
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID)
}
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID)
}
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL)
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
err = validateRecordConflicts(ctx, transaction, zone, record)
if err != nil {
return err
}
if err = transaction.CreateDNSRecord(ctx, record); err != nil {
return fmt.Errorf("failed to create dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return record, nil
}
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
var zone *zones.Zone
var record *records.Record
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content
record.Name = updatedRecord.Name
record.Type = updatedRecord.Type
record.Content = updatedRecord.Content
record.TTL = updatedRecord.TTL
if hasChanges {
if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil {
return err
}
}
if err = transaction.UpdateDNSRecord(ctx, record); err != nil {
return fmt.Errorf("failed to update dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return record, nil
}
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !ok {
return status.NewPermissionDeniedError()
}
var record *records.Record
var zone *zones.Zone
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID)
if err != nil {
return fmt.Errorf("failed to get zone: %w", err)
}
record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID)
if err != nil {
return fmt.Errorf("failed to get record: %w", err)
}
err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID)
if err != nil {
return fmt.Errorf("failed to delete dns record: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
})
if err != nil {
return err
}
meta := record.EventMeta(zone.ID, zone.Name)
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil
}
// validateRecordConflicts checks for duplicate records and CNAME conflicts
func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error {
if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) {
return status.Errorf(status.InvalidArgument, "record name does not belong to zone")
}
existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name)
if err != nil {
return fmt.Errorf("failed to check existing records: %w", err)
}
for _, existing := range existingRecords {
if existing.ID == record.ID {
continue
}
if existing.Type == record.Type && existing.Content == record.Content {
return status.Errorf(status.AlreadyExists, "identical record already exists")
}
if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME {
return status.Errorf(status.InvalidArgument,
"An A, AAAA, or CNAME record with name %s already exists", record.Name)
}
}
return nil
}

View File

@@ -0,0 +1,573 @@
package manager
import (
"context"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
const (
testAccountID = "test-account-id"
testUserID = "test-user-id"
testRecordID = "test-record-id"
testGroupID = "test-group-id"
)
func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) {
t.Helper()
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
Name: "Test Group",
},
},
})
require.NoError(t, err)
zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID})
err = testStore.CreateZone(ctx, zone)
require.NoError(t, err)
ctrl := gomock.NewController(t)
mockAccountManager := &mock_server.MockAccountManager{}
mockPermissionsManager := permissions.NewMockManager(ctrl)
manager := &managerImpl{
store: testStore,
accountManager: mockAccountManager,
permissionsManager: mockPermissionsManager,
}
return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup
}
func TestManagerImpl_GetAllRecords(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
assert.Len(t, result, 2)
assert.Equal(t, record1.ID, result[0].ID)
assert.Equal(t, record2.ID, result[1].ID)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("permission validation error", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
assert.Nil(t, result)
})
}
func TestManagerImpl_GetRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.Equal(t, record.ID, result.ID)
assert.Equal(t, record.Name, result.Name)
assert.Equal(t, record.Type, result.Type)
assert.Equal(t, record.Content, result.Content)
assert.Equal(t, record.TTL, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
}
func TestManagerImpl_CreateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success - A record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.NotEmpty(t, result.ID)
assert.Equal(t, testAccountID, result.AccountID)
assert.Equal(t, zone.ID, result.ZoneID)
assert.Equal(t, inputRecord.Name, result.Name)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
assert.Equal(t, inputRecord.TTL, result.TTL)
})
t.Run("success - AAAA record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "ipv6.example.com",
Type: records.RecordTypeAAAA,
Content: "2001:db8::1",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("success - CNAME record", func(t *testing.T) {
manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "www.example.com",
Type: records.RecordTypeCNAME,
Content: "example.com",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordCreated, activityID)
}
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, inputRecord.Type, result.Type)
assert.Equal(t, inputRecord.Content, result.Content)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record name not in zone", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
inputRecord := &records.Record{
Name: "api.different.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "does not belong to zone")
})
t.Run("duplicate record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "identical record already exists")
})
t.Run("CNAME conflict with existing A record", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
inputRecord := &records.Record{
Name: "api.example.com",
Type: records.RecordTypeCNAME,
Content: "example.com",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "already exists")
})
}
func TestManagerImpl_UpdateRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: existingRecord.ID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100", // Changed IP
TTL: 600, // Changed TTL
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, existingRecord.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordUpdated, activityID)
}
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, updatedRecord.Content, result.Content)
assert.Equal(t, updatedRecord.TTL, result.TTL)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
})
t.Run("update only TTL - no validation", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, existingRecord)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: existingRecord.ID,
Name: existingRecord.Name,
Type: existingRecord.Type,
Content: existingRecord.Content,
TTL: 600, // Only TTL changed
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored
}
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, 600, result.TTL)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: testRecordID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
updatedRecord := &records.Record{
ID: "non-existent-record",
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.100",
TTL: 600,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
})
t.Run("update creates duplicate", func(t *testing.T) {
manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record1)
require.NoError(t, err)
record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300)
err = testStore.CreateDNSRecord(ctx, record2)
require.NoError(t, err)
updatedRecord := &records.Record{
ID: record2.ID,
Name: "api.example.com",
Type: records.RecordTypeA,
Content: "192.168.1.1",
TTL: 300,
}
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "identical record already exists")
})
}
func TestManagerImpl_DeleteRecord(t *testing.T) {
ctx := context.Background()
t.Run("success", func(t *testing.T) {
manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300)
err := testStore.CreateDNSRecord(ctx, record)
require.NoError(t, err)
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
storeEventCalled = true
assert.Equal(t, testUserID, initiatorID)
assert.Equal(t, record.ID, targetID)
assert.Equal(t, testAccountID, accountID)
assert.Equal(t, activity.DNSRecordDeleted, activityID)
}
err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
assert.True(t, storeEventCalled, "StoreEvent should have been called")
_, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID)
require.Error(t, err)
})
t.Run("permission denied", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
s, ok := status.FromError(err)
assert.True(t, ok)
assert.Equal(t, status.PermissionDenied, s.Type())
})
t.Run("record not found", func(t *testing.T) {
manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t)
defer cleanup()
defer ctrl.Finish()
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err)
})
}

View File

@@ -0,0 +1,129 @@
package records
import (
"errors"
"net"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type RecordType string
const (
RecordTypeA RecordType = "A"
RecordTypeAAAA RecordType = "AAAA"
RecordTypeCNAME RecordType = "CNAME"
)
type Record struct {
AccountID string `gorm:"index"`
ZoneID string `gorm:"index"`
ID string `gorm:"primaryKey"`
Name string
Type RecordType
Content string
TTL int
}
func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record {
return &Record{
ID: xid.New().String(),
AccountID: accountID,
ZoneID: zoneID,
Name: name,
Type: recordType,
Content: content,
TTL: ttl,
}
}
func (r *Record) ToAPIResponse() *api.DNSRecord {
recordType := api.DNSRecordType(r.Type)
return &api.DNSRecord{
Id: r.ID,
Name: r.Name,
Type: recordType,
Content: r.Content,
Ttl: r.TTL,
}
}
func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) {
r.Name = req.Name
r.Type = RecordType(req.Type)
r.Content = req.Content
r.TTL = req.Ttl
}
func (r *Record) Validate() error {
if r.Name == "" {
return errors.New("record name is required")
}
if !util.IsValidDomain(r.Name) {
return errors.New("invalid record name format")
}
if r.Type == "" {
return errors.New("record type is required")
}
switch r.Type {
case RecordTypeA:
if err := validateIPv4(r.Content); err != nil {
return err
}
case RecordTypeAAAA:
if err := validateIPv6(r.Content); err != nil {
return err
}
case RecordTypeCNAME:
if !util.IsValidDomain(r.Content) {
return errors.New("invalid CNAME record format")
}
default:
return errors.New("invalid record type, must be A, AAAA, or CNAME")
}
if r.TTL < 0 {
return errors.New("TTL cannot be negative")
}
return nil
}
func (r *Record) EventMeta(zoneID, zoneName string) map[string]any {
return map[string]any{
"name": r.Name,
"type": string(r.Type),
"content": r.Content,
"ttl": r.TTL,
"zone_id": zoneID,
"zone_name": zoneName,
}
}
func validateIPv4(content string) error {
if content == "" {
return errors.New("A record is required") //nolint:staticcheck
}
ip := net.ParseIP(content)
if ip == nil || ip.To4() == nil {
return errors.New("A record must be a valid IPv4 address") //nolint:staticcheck
}
return nil
}
func validateIPv6(content string) error {
if content == "" {
return errors.New("AAAA record is required")
}
ip := net.ParseIP(content)
if ip == nil || ip.To4() != nil {
return errors.New("AAAA record must be a valid IPv6 address")
}
return nil
}

View File

@@ -0,0 +1,89 @@
package zones
import (
"errors"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
type Zone struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
Name string
Domain string
Enabled bool
EnableSearchDomain bool
DistributionGroups []string `gorm:"serializer:json"`
Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"`
}
func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone {
return &Zone{
ID: xid.New().String(),
AccountID: accountID,
Name: name,
Domain: domain,
Enabled: enabled,
EnableSearchDomain: enableSearchDomain,
DistributionGroups: distributionGroups,
}
}
func (z *Zone) ToAPIResponse() *api.Zone {
apiRecords := make([]api.DNSRecord, 0, len(z.Records))
for _, record := range z.Records {
if apiRecord := record.ToAPIResponse(); apiRecord != nil {
apiRecords = append(apiRecords, *apiRecord)
}
}
return &api.Zone{
DistributionGroups: z.DistributionGroups,
Domain: z.Domain,
EnableSearchDomain: z.EnableSearchDomain,
Enabled: z.Enabled,
Id: z.ID,
Name: z.Name,
Records: apiRecords,
}
}
func (z *Zone) FromAPIRequest(req *api.ZoneRequest) {
z.Name = req.Name
z.Domain = req.Domain
z.EnableSearchDomain = req.EnableSearchDomain
z.DistributionGroups = req.DistributionGroups
enabled := true
if req.Enabled != nil {
enabled = *req.Enabled
}
z.Enabled = enabled
}
func (z *Zone) Validate() error {
if z.Name == "" {
return errors.New("zone name is required")
}
if len(z.Name) > 255 {
return errors.New("zone name exceeds maximum length of 255 characters")
}
if !util.IsValidDomain(z.Domain) {
return errors.New("invalid zone domain format")
}
if len(z.DistributionGroups) == 0 {
return errors.New("at least one distribution group is required")
}
return nil
}
func (z *Zone) EventMeta() map[string]any {
return map[string]any{"name": z.Name, "domain": z.Domain}
}

View File

@@ -92,7 +92,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController(), s.IdpManager())
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}

View File

@@ -8,6 +8,10 @@ import (
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/zones"
zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/geolocation"
@@ -158,3 +162,15 @@ func (s *BaseServer) NetworksManager() networks.Manager {
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
})
}
func (s *BaseServer) ZonesManager() zones.Manager {
return Create(s, func() zones.Manager {
return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain())
})
}
func (s *BaseServer) RecordsManager() records.Manager {
return Create(s, func() records.Manager {
return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager())
})
}

View File

@@ -374,9 +374,10 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool {
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
NonAuthoritative: zone.NonAuthoritative,
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{