mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-03 23:56:38 +00:00
[management] Add custom dns zones (#4849)
This commit is contained in:
@@ -22,6 +22,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -4025,3 +4027,476 @@ func TestSqlStore_ExecuteInTransaction_Timeout(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "transaction has already been committed or rolled back", "expected transaction rolled back error, got: %v", err)
|
||||
}
|
||||
|
||||
func TestSqlStore_CreateZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, zone.ID, savedZone.ID)
|
||||
assert.Equal(t, zone.Name, savedZone.Name)
|
||||
assert.Equal(t, zone.Domain, savedZone.Domain)
|
||||
assert.Equal(t, zone.Enabled, savedZone.Enabled)
|
||||
assert.Equal(t, zone.EnableSearchDomain, savedZone.EnableSearchDomain)
|
||||
assert.Equal(t, zone.DistributionGroups, savedZone.DistributionGroups)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
zoneID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing zone",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing zone",
|
||||
accountID: accountID,
|
||||
zoneID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty zone ID",
|
||||
accountID: accountID,
|
||||
zoneID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, savedZone)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, tt.zoneID, savedZone.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountZones(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone1 := zones.NewZone(accountID, "Zone 1", "example1.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone1)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone2 := zones.NewZone(accountID, "Zone 2", "example2.com", true, true, []string{"group1", "group2"})
|
||||
err = store.CreateZone(context.Background(), zone2)
|
||||
require.NoError(t, err)
|
||||
|
||||
allZones, err := store.GetAccountZones(context.Background(), LockingStrengthNone, accountID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, allZones)
|
||||
assert.GreaterOrEqual(t, len(allZones), 2)
|
||||
|
||||
zoneIDs := make(map[string]bool)
|
||||
for _, z := range allZones {
|
||||
zoneIDs[z.ID] = true
|
||||
}
|
||||
assert.True(t, zoneIDs[zone1.ID])
|
||||
assert.True(t, zoneIDs[zone2.ID])
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneByDomain(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
otherAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3c"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
domain string
|
||||
expectError bool
|
||||
errorType status.Type
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing zone by domain",
|
||||
accountID: accountID,
|
||||
domain: "example.com",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing zone domain",
|
||||
accountID: accountID,
|
||||
domain: "non-existing.com",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty domain",
|
||||
accountID: accountID,
|
||||
domain: "",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
{
|
||||
name: "retrieve with different account ID",
|
||||
accountID: otherAccountID,
|
||||
domain: "example.com",
|
||||
expectError: true,
|
||||
errorType: status.NotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedZone, err := store.GetZoneByDomain(context.Background(), tt.accountID, tt.domain)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, tt.errorType, sErr.Type())
|
||||
require.Nil(t, savedZone)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedZone)
|
||||
assert.Equal(t, tt.domain, savedZone.Domain)
|
||||
assert.Equal(t, zone.ID, savedZone.ID)
|
||||
assert.Equal(t, zone.Name, savedZone.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
zone.Name = "Updated Zone"
|
||||
zone.Domain = "updated.com"
|
||||
zone.Enabled = false
|
||||
zone.EnableSearchDomain = true
|
||||
zone.DistributionGroups = []string{"group2", "group3"}
|
||||
|
||||
err = store.UpdateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedZone)
|
||||
assert.Equal(t, "Updated Zone", updatedZone.Name)
|
||||
assert.Equal(t, "updated.com", updatedZone.Domain)
|
||||
assert.False(t, updatedZone.Enabled)
|
||||
assert.True(t, updatedZone.EnableSearchDomain)
|
||||
assert.Equal(t, []string{"group2", "group3"}, updatedZone.DistributionGroups)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteZone(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.DeleteZone(context.Background(), accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
deletedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, deletedZone)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
}
|
||||
|
||||
func TestSqlStore_CreateDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedRecord)
|
||||
assert.Equal(t, record.ID, savedRecord.ID)
|
||||
assert.Equal(t, record.Name, savedRecord.Name)
|
||||
assert.Equal(t, record.Type, savedRecord.Type)
|
||||
assert.Equal(t, record.Content, savedRecord.Content)
|
||||
assert.Equal(t, record.TTL, savedRecord.TTL)
|
||||
assert.Equal(t, zone.ID, savedRecord.ZoneID)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetDNSRecordByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
zoneID string
|
||||
recordID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing record",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: record.ID,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing record",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty record ID",
|
||||
accountID: accountID,
|
||||
zoneID: zone.ID,
|
||||
recordID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID, tt.recordID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, savedRecord)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, savedRecord)
|
||||
assert.Equal(t, tt.recordID, savedRecord.ID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneDNSRecords(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordA := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordA)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordAAAA := records.NewRecord(accountID, zone.ID, "ipv6.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordAAAA)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordCNAME := records.NewRecord(accountID, zone.ID, "alias.example.com", records.RecordTypeCNAME, "www.example.com", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), recordCNAME)
|
||||
require.NoError(t, err)
|
||||
|
||||
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, allRecords)
|
||||
assert.Equal(t, 3, len(allRecords))
|
||||
|
||||
recordIDs := make(map[string]bool)
|
||||
for _, r := range allRecords {
|
||||
recordIDs[r.ID] = true
|
||||
}
|
||||
assert.True(t, recordIDs[recordA.ID])
|
||||
assert.True(t, recordIDs[recordAAAA.ID])
|
||||
assert.True(t, recordIDs[recordCNAME.ID])
|
||||
}
|
||||
|
||||
func TestSqlStore_GetZoneDNSRecordsByName(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeAAAA, "2001:db8::1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
record3 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
|
||||
err = store.CreateDNSRecord(context.Background(), record3)
|
||||
require.NoError(t, err)
|
||||
|
||||
recordsByName, err := store.GetZoneDNSRecordsByName(context.Background(), LockingStrengthNone, accountID, zone.ID, "www.example.com")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, recordsByName)
|
||||
assert.Equal(t, 2, len(recordsByName))
|
||||
|
||||
for _, r := range recordsByName {
|
||||
assert.Equal(t, "www.example.com", r.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_UpdateDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
record.Name = "api.example.com"
|
||||
record.Content = "192.168.1.100"
|
||||
record.TTL = 600
|
||||
|
||||
err = store.UpdateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updatedRecord)
|
||||
assert.Equal(t, "api.example.com", updatedRecord.Name)
|
||||
assert.Equal(t, "192.168.1.100", updatedRecord.Content)
|
||||
assert.Equal(t, 600, updatedRecord.TTL)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteDNSRecord(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.DeleteDNSRecord(context.Background(), accountID, zone.ID, record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
deletedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, deletedRecord)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"})
|
||||
err = store.CreateZone(context.Background(), zone)
|
||||
require.NoError(t, err)
|
||||
|
||||
record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300)
|
||||
err = store.CreateDNSRecord(context.Background(), record1)
|
||||
require.NoError(t, err)
|
||||
|
||||
record2 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600)
|
||||
err = store.CreateDNSRecord(context.Background(), record2)
|
||||
require.NoError(t, err)
|
||||
|
||||
allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(allRecords))
|
||||
|
||||
err = store.DeleteZoneDNSRecords(context.Background(), accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
remainingRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(remainingRecords))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user