[management] Add custom dns zones (#4849)

This commit is contained in:
Bethuel Mmbaga
2026-01-16 10:12:05 +01:00
committed by GitHub
parent 291e640b28
commit 067c77e49e
36 changed files with 4837 additions and 63 deletions

View File

@@ -22,6 +22,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -4025,3 +4027,476 @@ func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
}
func TestSqlStore_CreateZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, zone.ID, savedZone.ID)
assert.Equal(t, zone.Name, savedZone.Name)
assert.Equal(t, zone.Domain, savedZone.Domain)
assert.Equal(t, zone.Enabled, savedZone.Enabled)
assert.Equal(t, zone.EnableSearchDomain, savedZone.EnableSearchDomain)
assert.Equal(t, zone.DistributionGroups, savedZone.DistributionGroups)
}
func TestSqlStore_GetZoneByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
tests := []struct {
name string
accountID string
zoneID string
expectError bool
}{
{
name: "retrieve existing zone",
accountID: accountID,
zoneID: zone.ID,
expectError: false,
},
{
name: "retrieve non-existing zone",
accountID: accountID,
zoneID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty zone ID",
accountID: accountID,
zoneID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, savedZone)
} else {
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, tt.zoneID, savedZone.ID)
}
})
}
}
func TestSqlStore_GetAccountZones(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone1 := zones.NewZone(accountID, "Zone 1", "example1.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone1)
require.NoError(t, err)
zone2 := zones.NewZone(accountID, "Zone 2", "example2.com", true, true, []string{"group1", "group2"})
err = store.CreateZone(context.Background(), zone2)
require.NoError(t, err)
allZones, err := store.GetAccountZones(context.Background(), LockingStrengthNone, accountID)
require.NoError(t, err)
require.NotNil(t, allZones)
assert.GreaterOrEqual(t, len(allZones), 2)
zoneIDs := make(map[string]bool)
for _, z := range allZones {
zoneIDs[z.ID] = true
}
assert.True(t, zoneIDs[zone1.ID])
assert.True(t, zoneIDs[zone2.ID])
}
func TestSqlStore_GetZoneByDomain(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
otherAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3c"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
tests := []struct {
name string
accountID string
domain string
expectError bool
errorType status.Type
}{
{
name: "retrieve existing zone by domain",
accountID: accountID,
domain: "example.com",
expectError: false,
},
{
name: "retrieve non-existing zone domain",
accountID: accountID,
domain: "non-existing.com",
expectError: true,
errorType: status.NotFound,
},
{
name: "retrieve with empty domain",
accountID: accountID,
domain: "",
expectError: true,
errorType: status.NotFound,
},
{
name: "retrieve with different account ID",
accountID: otherAccountID,
domain: "example.com",
expectError: true,
errorType: status.NotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedZone, err := store.GetZoneByDomain(context.Background(), tt.accountID, tt.domain)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, tt.errorType, sErr.Type())
require.Nil(t, savedZone)
} else {
require.NoError(t, err)
require.NotNil(t, savedZone)
assert.Equal(t, tt.domain, savedZone.Domain)
assert.Equal(t, zone.ID, savedZone.ID)
assert.Equal(t, zone.Name, savedZone.Name)
}
})
}
}
func TestSqlStore_UpdateZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
zone.Name = "Updated Zone"
zone.Domain = "updated.com"
zone.Enabled = false
zone.EnableSearchDomain = true
zone.DistributionGroups = []string{"group2", "group3"}
err = store.UpdateZone(context.Background(), zone)
require.NoError(t, err)
updatedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, updatedZone)
assert.Equal(t, "Updated Zone", updatedZone.Name)
assert.Equal(t, "updated.com", updatedZone.Domain)
assert.False(t, updatedZone.Enabled)
assert.True(t, updatedZone.EnableSearchDomain)
assert.Equal(t, []string{"group2", "group3"}, updatedZone.DistributionGroups)
}
func TestSqlStore_DeleteZone(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
err = store.DeleteZone(context.Background(), accountID, zone.ID)
require.NoError(t, err)
deletedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.Error(t, err)
require.Nil(t, deletedZone)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
}
func TestSqlStore_CreateDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.NoError(t, err)
require.NotNil(t, savedRecord)
assert.Equal(t, record.ID, savedRecord.ID)
assert.Equal(t, record.Name, savedRecord.Name)
assert.Equal(t, record.Type, savedRecord.Type)
assert.Equal(t, record.Content, savedRecord.Content)
assert.Equal(t, record.TTL, savedRecord.TTL)
assert.Equal(t, zone.ID, savedRecord.ZoneID)
}
func TestSqlStore_GetDNSRecordByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
tests := []struct {
name string
accountID string
zoneID string
recordID string
expectError bool
}{
{
name: "retrieve existing record",
accountID: accountID,
zoneID: zone.ID,
recordID: record.ID,
expectError: false,
},
{
name: "retrieve non-existing record",
accountID: accountID,
zoneID: zone.ID,
recordID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty record ID",
accountID: accountID,
zoneID: zone.ID,
recordID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID, tt.recordID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, savedRecord)
} else {
require.NoError(t, err)
require.NotNil(t, savedRecord)
assert.Equal(t, tt.recordID, savedRecord.ID)
}
})
}
}
func TestSqlStore_GetZoneDNSRecords(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
recordA := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), recordA)
require.NoError(t, err)
recordAAAA := records.NewRecord(accountID, zone.ID, "ipv6.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
err = store.CreateDNSRecord(context.Background(), recordAAAA)
require.NoError(t, err)
recordCNAME := records.NewRecord(accountID, zone.ID, "alias.example.com", records.RecordTypeCNAME, "www.example.com", 300)
err = store.CreateDNSRecord(context.Background(), recordCNAME)
require.NoError(t, err)
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
require.NotNil(t, allRecords)
assert.Equal(t, 3, len(allRecords))
recordIDs := make(map[string]bool)
for _, r := range allRecords {
recordIDs[r.ID] = true
}
assert.True(t, recordIDs[recordA.ID])
assert.True(t, recordIDs[recordAAAA.ID])
assert.True(t, recordIDs[recordCNAME.ID])
}
func TestSqlStore_GetZoneDNSRecordsByName(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record1)
require.NoError(t, err)
record2 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
err = store.CreateDNSRecord(context.Background(), record2)
require.NoError(t, err)
record3 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
err = store.CreateDNSRecord(context.Background(), record3)
require.NoError(t, err)
recordsByName, err := store.GetZoneDNSRecordsByName(context.Background(), LockingStrengthNone, accountID, zone.ID, "www.example.com")
require.NoError(t, err)
require.NotNil(t, recordsByName)
assert.Equal(t, 2, len(recordsByName))
for _, r := range recordsByName {
assert.Equal(t, "www.example.com", r.Name)
}
}
func TestSqlStore_UpdateDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
record.Name = "api.example.com"
record.Content = "192.168.1.100"
record.TTL = 600
err = store.UpdateDNSRecord(context.Background(), record)
require.NoError(t, err)
updatedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.NoError(t, err)
require.NotNil(t, updatedRecord)
assert.Equal(t, "api.example.com", updatedRecord.Name)
assert.Equal(t, "192.168.1.100", updatedRecord.Content)
assert.Equal(t, 600, updatedRecord.TTL)
}
func TestSqlStore_DeleteDNSRecord(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record)
require.NoError(t, err)
err = store.DeleteDNSRecord(context.Background(), accountID, zone.ID, record.ID)
require.NoError(t, err)
deletedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
require.Error(t, err)
require.Nil(t, deletedRecord)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
}
func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
err = store.CreateZone(context.Background(), zone)
require.NoError(t, err)
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
err = store.CreateDNSRecord(context.Background(), record1)
require.NoError(t, err)
record2 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
err = store.CreateDNSRecord(context.Background(), record2)
require.NoError(t, err)
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 2, len(allRecords))
err = store.DeleteZoneDNSRecords(context.Background(), accountID, zone.ID)
require.NoError(t, err)
remainingRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
require.NoError(t, err)
assert.Equal(t, 0, len(remainingRecords))
}