diff --git a/management/server/account.go b/management/server/account.go index 614292f46..acbbd1879 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -89,6 +89,8 @@ type AccountManager interface { ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) GetDNSDomain() string GetEvents(accountID, userID string) ([]*activity.Event, error) + GetDNSSettings(accountID string, userID string) (*DNSSettings, error) + SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error } type DefaultAccountManager struct { @@ -129,6 +131,7 @@ type Account struct { Rules map[string]*Rule Routes map[string]*route.Route NameServerGroups map[string]*nbdns.NameServerGroup + DNSSettings *DNSSettings } type UserInfo struct { @@ -339,6 +342,21 @@ func (a *Account) getUserGroups(userID string) ([]string, error) { return user.AutoGroups, nil } +func (a *Account) getPeerDNSManagementStatus(peerID string) bool { + peerGroups := a.getPeerGroups(peerID) + enabled := true + if a.DNSSettings != nil { + for _, groupID := range a.DNSSettings.DisabledManagementGroups { + _, found := peerGroups[groupID] + if found { + enabled = false + break + } + } + } + return enabled +} + func (a *Account) getPeerGroups(peerID string) lookupMap { groupList := make(lookupMap) for groupID, group := range a.Groups { @@ -415,6 +433,11 @@ func (a *Account) Copy() *Account { nsGroups[id] = nsGroup.Copy() } + var dnsSettings *DNSSettings + if a.DNSSettings != nil { + dnsSettings = a.DNSSettings + } + return &Account{ Id: a.Id, CreatedBy: a.CreatedBy, @@ -429,6 +452,7 @@ func (a *Account) Copy() *Account { Rules: rules, Routes: routes, NameServerGroups: nsGroups, + DNSSettings: dnsSettings, } } @@ -1042,6 +1066,9 @@ func newAccountWithId(accountId, userId, domain string) *Account { routes := make(map[string]*route.Route) nameServersGroups := make(map[string]*nbdns.NameServerGroup) users[userId] = NewAdminUser(userId) + dnsSettings := &DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key) acc := &Account{ @@ -1054,6 +1081,7 @@ func newAccountWithId(accountId, userId, domain string) *Account { Domain: domain, Routes: routes, NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, } addAllGroup(acc) diff --git a/management/server/account_test.go b/management/server/account_test.go index 353fe7528..00cc3403d 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1276,6 +1276,7 @@ func TestAccount_Copy(t *testing.T) { ID: "nsGroup1", }, }, + DNSSettings: &DNSSettings{}, } err := hasNilField(account) if err != nil { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 6ebe0bc2e..1cdecd933 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -45,6 +45,10 @@ const ( GroupAddedToSetupKey // GroupRemovedFromSetupKey indicates that a user removed a group from a setup key GroupRemovedFromSetupKey + // GroupAddedToDisabledManagementGroups indicates that a user added a group to the DNS setting Disabled management groups + GroupAddedToDisabledManagementGroups + // GroupRemovedFromDisabledManagementGroups indicates that a user removed a group from the DNS setting Disabled management groups + GroupRemovedFromDisabledManagementGroups ) const ( @@ -92,6 +96,10 @@ const ( GroupAddedToSetupKeyMessage string = "Group added to setup key" // GroupRemovedFromSetupKeyMessage is a human-readable text message of the GroupRemovedFromSetupKey activity GroupRemovedFromSetupKeyMessage string = "Group removed from user setup key" + // GroupAddedToDisabledManagementGroupsMessage is a human-readable text message of the GroupAddedToDisabledManagementGroups activity + GroupAddedToDisabledManagementGroupsMessage + // GroupRemovedFromDisabledManagementGroupsMessage is a human-readable text message of the GroupRemovedFromDisabledManagementGroups activity + GroupRemovedFromDisabledManagementGroupsMessage ) // Activity that triggered an Event @@ -144,6 +152,10 @@ func (a Activity) Message() string { return GroupAddedToSetupKeyMessage case GroupRemovedFromSetupKey: return GroupRemovedFromSetupKeyMessage + case GroupAddedToDisabledManagementGroups: + return GroupAddedToDisabledManagementGroupsMessage + case GroupRemovedFromDisabledManagementGroups: + return GroupRemovedFromDisabledManagementGroupsMessage default: return "UNKNOWN_ACTIVITY" } @@ -196,6 +208,10 @@ func (a Activity) StringCode() string { return "setupkey.group.add" case GroupRemovedFromSetupKey: return "setupkey.group.delete" + case GroupAddedToDisabledManagementGroups: + return "dns.setting.disabled.management.group.add" + case GroupRemovedFromDisabledManagementGroups: + return "dns.setting.disabled.management.group.delete" default: return "UNKNOWN_ACTIVITY" } diff --git a/management/server/dns.go b/management/server/dns.go index 85c4e481f..1c0499725 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -5,13 +5,123 @@ import ( "github.com/miekg/dns" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/status" log "github.com/sirupsen/logrus" "strconv" ) +const defaultTTL = 300 + type lookupMap map[string]struct{} -const defaultTTL = 300 +// DNSSettings defines dns settings at the account level +type DNSSettings struct { + // DisabledManagementGroups groups whose DNS management is disabled + DisabledManagementGroups []string +} + +// Copy returns a copy of the DNS settings +func (d *DNSSettings) Copy() *DNSSettings { + settings := &DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + + if d == nil { + return settings + } + + if d.DisabledManagementGroups != nil && len(d.DisabledManagementGroups) > 0 { + settings.DisabledManagementGroups = d.DisabledManagementGroups[:] + } + + return settings +} + +// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID +func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.IsAdmin() { + return nil, status.Errorf(status.PermissionDenied, "only admins are allowed to view DNS settings") + } + + if account.DNSSettings == nil { + return &DNSSettings{}, nil + } + + return account.DNSSettings.Copy(), nil +} + +// SaveDNSSettings validates a user role and updates the account's DNS settings +func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error { + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + user, err := account.FindUser(userID) + if err != nil { + return err + } + + if !user.IsAdmin() { + return status.Errorf(status.PermissionDenied, "only admins are allowed to update DNS settings") + } + + if dnsSettingsToSave == nil { + return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") + } + + err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) + if err != nil { + return err + } + + oldSettings := &DNSSettings{} + if account.DNSSettings != nil { + oldSettings = account.DNSSettings.Copy() + } + + account.DNSSettings = dnsSettingsToSave.Copy() + + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + go func() { + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + for _, id := range addedGroups { + group := account.GetGroup(id) + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.storeEvent(userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + } + + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + for _, id := range removedGroups { + group := account.GetGroup(id) + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.storeEvent(userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + } + }() + + return am.updateAccountPeers(account) +} func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable} diff --git a/management/server/dns_test.go b/management/server/dns_test.go new file mode 100644 index 000000000..d1d83269a --- /dev/null +++ b/management/server/dns_test.go @@ -0,0 +1,266 @@ +package server + +import ( + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/status" + "github.com/stretchr/testify/require" + "testing" +) + +const ( + dnsGroup1ID = "group1" + dnsGroup2ID = "group2" + dnsPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=" + dnsPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI=" + dnsAccountID = "testingAcc" + dnsAdminUserID = "testingAdminUser" + dnsRegularUserID = "testingRegularUser" +) + +func TestGetDNSSettings(t *testing.T) { + am, err := createDNSManager(t) + if err != nil { + t.Error("failed to create account manager") + } + + account, err := initTestDNSAccount(t, am) + if err != nil { + t.Error("failed to init testing account") + } + + dnsSettings, err := am.GetDNSSettings(account.Id, dnsAdminUserID) + if err != nil { + t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) + } + + if dnsSettings == nil { + t.Fatal("DNS settings for new accounts shouldn't return nil") + } + + account.DNSSettings = &DNSSettings{ + DisabledManagementGroups: []string{group1ID}, + } + + err = am.Store.SaveAccount(account) + if err != nil { + t.Error("failed to save testing account with new DNS settings") + } + + dnsSettings, err = am.GetDNSSettings(account.Id, dnsAdminUserID) + if err != nil { + t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) + } + + if len(dnsSettings.DisabledManagementGroups) != 1 { + t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) + } + + _, err = am.GetDNSSettings(account.Id, dnsRegularUserID) + if err == nil { + t.Errorf("An error should be returned when getting the DNS settings with a regular user") + } + + s, ok := status.FromError(err) + if !ok && s.Type() != status.PermissionDenied { + t.Errorf("returned error should be Permission Denied, got err: %s", err) + } +} + +func TestSaveDNSSettings(t *testing.T) { + testCases := []struct { + name string + userID string + inputSettings *DNSSettings + shouldFail bool + }{ + { + name: "Saving As Admin Should Be OK", + userID: dnsAdminUserID, + inputSettings: &DNSSettings{ + DisabledManagementGroups: []string{dnsGroup1ID}, + }, + }, + { + name: "Should Not Update Settings As Regular User", + userID: dnsRegularUserID, + inputSettings: &DNSSettings{ + DisabledManagementGroups: []string{dnsGroup1ID}, + }, + shouldFail: true, + }, + { + name: "Should Not Update Settings If Input is Nil", + userID: dnsAdminUserID, + inputSettings: nil, + shouldFail: true, + }, + { + name: "Should Not Update Settings If Group Is Invalid", + userID: dnsAdminUserID, + inputSettings: &DNSSettings{ + DisabledManagementGroups: []string{"non-existing-group"}, + }, + shouldFail: true, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + am, err := createDNSManager(t) + if err != nil { + t.Error("failed to create account manager") + } + + account, err := initTestDNSAccount(t, am) + if err != nil { + t.Error("failed to init testing account") + } + + err = am.SaveDNSSettings(account.Id, testCase.userID, testCase.inputSettings) + if err != nil { + if testCase.shouldFail { + return + } + t.Error(err) + } + + updatedAccount, err := am.Store.GetAccount(account.Id) + if err != nil { + t.Errorf("should be able to retrieve updated account, got err: %s", err) + } + + require.ElementsMatchf(t, testCase.inputSettings.DisabledManagementGroups, updatedAccount.DNSSettings.DisabledManagementGroups, + "resulting DNS settings should match input") + + }) + } +} + +func TestGetNetworkMap_DNSConfigSync(t *testing.T) { + + am, err := createDNSManager(t) + if err != nil { + t.Error("failed to create account manager") + } + + account, err := initTestDNSAccount(t, am) + if err != nil { + t.Error("failed to init testing account") + } + + newAccountDNSConfig, err := am.GetNetworkMap(dnsPeer1Key) + require.NoError(t, err) + require.Len(t, newAccountDNSConfig.DNSConfig.CustomZones, 1, "default DNS config should have one custom zone for peers") + require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") + + dnsSettings := account.DNSSettings.Copy() + dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) + account.DNSSettings = dnsSettings + err = am.Store.SaveAccount(account) + require.NoError(t, err) + + updatedAccountDNSConfig, err := am.GetNetworkMap(dnsPeer1Key) + require.NoError(t, err) + require.Len(t, updatedAccountDNSConfig.DNSConfig.CustomZones, 0, "updated DNS config should have no custom zone when peer belongs to a disabled group") + require.False(t, updatedAccountDNSConfig.DNSConfig.ServiceEnable, "updated DNS config should have local DNS service disabled when peer belongs to a disabled group") + + peer2AccountDNSConfig, err := am.GetNetworkMap(dnsPeer2Key) + require.NoError(t, err) + require.Len(t, peer2AccountDNSConfig.DNSConfig.CustomZones, 1, "DNS config should have one custom zone for peers not in the disabled group") + require.True(t, peer2AccountDNSConfig.DNSConfig.ServiceEnable, "DNS config should have DNS service enabled for peers not in the disabled group") +} + +func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { + store, err := createDNSStore(t) + if err != nil { + return nil, err + } + eventStore := &activity.InMemoryEventStore{} + return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore) +} + +func createDNSStore(t *testing.T) (Store, error) { + dataDir := t.TempDir() + store, err := NewFileStore(dataDir) + if err != nil { + return nil, err + } + + return store, nil +} + +func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { + peer1 := &Peer{ + Key: dnsPeer1Key, + Name: "test-host1@netbird.io", + Meta: PeerSystemMeta{ + Hostname: "test-host1@netbird.io", + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + DNSLabel: dnsPeer1Key, + } + peer2 := &Peer{ + Key: dnsPeer2Key, + Name: "test-host2@netbird.io", + Meta: PeerSystemMeta{ + Hostname: "test-host2@netbird.io", + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + DNSLabel: dnsPeer2Key, + } + + domain := "example.com" + + account := newAccountWithId(dnsAccountID, dnsAdminUserID, domain) + + account.Users[dnsRegularUserID] = &User{ + Id: dnsRegularUserID, + Role: UserRoleUser, + } + + err := am.Store.SaveAccount(account) + if err != nil { + return nil, err + } + + newGroup1 := &Group{ + ID: dnsGroup1ID, + Peers: []string{peer1.Key}, + Name: dnsGroup1ID, + } + + newGroup2 := &Group{ + ID: dnsGroup2ID, + Name: dnsGroup2ID, + } + + account.Groups[newGroup1.ID] = newGroup1 + account.Groups[newGroup2.ID] = newGroup2 + + err = am.Store.SaveAccount(account) + if err != nil { + return nil, err + } + + _, err = am.AddPeer("", dnsAdminUserID, peer1) + if err != nil { + return nil, err + } + _, err = am.AddPeer("", dnsAdminUserID, peer2) + if err != nil { + return nil, err + } + + return account, nil +} diff --git a/management/server/event.go b/management/server/event.go index 3c52d94b3..df19b9ed4 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -3,6 +3,8 @@ package server import ( "fmt" "github.com/netbirdio/netbird/management/server/activity" + log "github.com/sirupsen/logrus" + "time" ) // GetEvents returns a list of activity events of an account @@ -31,3 +33,17 @@ func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activit return filtered, nil } + +func (am *DefaultAccountManager) storeEvent(initiatorID, targetID, accountID string, activityID activity.Activity, meta map[string]any) { + _, err := am.eventStore.Save(&activity.Event{ + Timestamp: time.Now(), + Activity: activityID, + InitiatorID: initiatorID, + TargetID: targetID, + AccountID: accountID, + Meta: meta, + }) + if err != nil { + log.Errorf("received an error while storing an activity event, error: %s", err) + } +} diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 6071d1788..79b8e9135 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -503,6 +503,16 @@ components: enum: [ "name", "description", "enabled", "groups", "nameservers", "primary", "domains" ] required: - path + DNSSettings: + type: object + properties: + disabled_management_groups: + description: Groups whose DNS management is disabled + type: array + items: + type: string + required: + - disabled_management_groups Event: type: object properties: @@ -522,10 +532,10 @@ components: enum: [ "user.peer.delete", "user.join", "user.invite", "user.peer.add", "user.group.add", "user.group.delete", "user.role.update", "setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse", - "setupkey.group.delete", "setupkey.group.add" + "setupkey.group.delete", "setupkey.group.add", "rule.add", "rule.delete", "rule.update", - "group.add", "group.update", - "account.create", + "group.add", "group.update", "dns.setting.disabled.management.group.add", + "account.create", "dns.setting.disabled.management.group.delete" ] initiator_id: description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event. @@ -1619,6 +1629,55 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + + /api/dns/settings: + get: + summary: Returns a DNS settings object + tags: [ DNS ] + security: + - BearerAuth: [ ] + responses: + '200': + description: A JSON Object of DNS Setting + content: + application/json: + schema: + items: + $ref: '#/components/schemas/DNSSettings' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Updates a DNS settings object + tags: [ DNS ] + security: + - BearerAuth: [ ] + requestBody: + description: A DNS settings object + content: + 'application/json': + schema: + $ref: '#/components/schemas/DNSSettings' + responses: + '200': + description: A JSON Object of DNS Setting + content: + application/json: + schema: + $ref: '#/components/schemas/DNSSettings' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/events: get: summary: Returns a list of all events diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 6fe25d1df..550d44731 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -13,24 +13,28 @@ const ( // Defines values for EventActivityCode. const ( - EventActivityCodeAccountCreate EventActivityCode = "account.create" - EventActivityCodeGroupAdd EventActivityCode = "group.add" - EventActivityCodeGroupUpdate EventActivityCode = "group.update" - EventActivityCodeRuleAdd EventActivityCode = "rule.add" - EventActivityCodeRuleDelete EventActivityCode = "rule.delete" - EventActivityCodeRuleUpdate EventActivityCode = "rule.update" - EventActivityCodeSetupkeyAdd EventActivityCode = "setupkey.add" - EventActivityCodeSetupkeyOveruse EventActivityCode = "setupkey.overuse" - EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add" - EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke" - EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update" - EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add" - EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete" - EventActivityCodeUserInvite EventActivityCode = "user.invite" - EventActivityCodeUserJoin EventActivityCode = "user.join" - EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add" - EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete" - EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update" + EventActivityCodeAccountCreate EventActivityCode = "account.create" + EventActivityCodeDnsSettingDisabledManagementGroupAdd EventActivityCode = "dns.setting.disabled.management.group.add" + EventActivityCodeDnsSettingDisabledManagementGroupDelete EventActivityCode = "dns.setting.disabled.management.group.delete" + EventActivityCodeGroupAdd EventActivityCode = "group.add" + EventActivityCodeGroupUpdate EventActivityCode = "group.update" + EventActivityCodeRuleAdd EventActivityCode = "rule.add" + EventActivityCodeRuleDelete EventActivityCode = "rule.delete" + EventActivityCodeRuleUpdate EventActivityCode = "rule.update" + EventActivityCodeSetupkeyAdd EventActivityCode = "setupkey.add" + EventActivityCodeSetupkeyGroupAdd EventActivityCode = "setupkey.group.add" + EventActivityCodeSetupkeyGroupDelete EventActivityCode = "setupkey.group.delete" + EventActivityCodeSetupkeyOveruse EventActivityCode = "setupkey.overuse" + EventActivityCodeSetupkeyPeerAdd EventActivityCode = "setupkey.peer.add" + EventActivityCodeSetupkeyRevoke EventActivityCode = "setupkey.revoke" + EventActivityCodeSetupkeyUpdate EventActivityCode = "setupkey.update" + EventActivityCodeUserGroupAdd EventActivityCode = "user.group.add" + EventActivityCodeUserGroupDelete EventActivityCode = "user.group.delete" + EventActivityCodeUserInvite EventActivityCode = "user.invite" + EventActivityCodeUserJoin EventActivityCode = "user.join" + EventActivityCodeUserPeerAdd EventActivityCode = "user.peer.add" + EventActivityCodeUserPeerDelete EventActivityCode = "user.peer.delete" + EventActivityCodeUserRoleUpdate EventActivityCode = "user.role.update" ) // Defines values for GroupPatchOperationOp. @@ -119,6 +123,12 @@ const ( UserStatusInvited UserStatus = "invited" ) +// DNSSettings defines model for DNSSettings. +type DNSSettings struct { + // DisabledManagementGroups Groups whose DNS management is disabled + DisabledManagementGroups []string `json:"disabled_management_groups"` +} + // Event defines model for Event. type Event struct { // Activity The activity that occurred during the event @@ -657,6 +667,9 @@ type PatchApiDnsNameserversIdJSONRequestBody = PatchApiDnsNameserversIdJSONBody // PutApiDnsNameserversIdJSONRequestBody defines body for PutApiDnsNameserversId for application/json ContentType. type PutApiDnsNameserversIdJSONRequestBody = NameserverGroupRequest +// PutApiDnsSettingsJSONRequestBody defines body for PutApiDnsSettings for application/json ContentType. +type PutApiDnsSettingsJSONRequestBody = DNSSettings + // PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType. type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody diff --git a/management/server/http/dns_settings.go b/management/server/http/dns_settings.go new file mode 100644 index 000000000..92c5a8322 --- /dev/null +++ b/management/server/http/dns_settings.go @@ -0,0 +1,83 @@ +package http + +import ( + "encoding/json" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + log "github.com/sirupsen/logrus" + "net/http" +) + +// DNSSettings is a handler that returns the DNS settings of the account +type DNSSettings struct { + jwtExtractor jwtclaims.ClaimsExtractor + accountManager server.AccountManager + authAudience string +} + +// NewDNSSettings returns a new instance of DNSSettings handler +func NewDNSSettings(accountManager server.AccountManager, authAudience string) *DNSSettings { + return &DNSSettings{ + accountManager: accountManager, + authAudience: authAudience, + jwtExtractor: *jwtclaims.NewClaimsExtractor(nil), + } +} + +// GetDNSSettings returns the DNS settings for the account +func (h *DNSSettings) GetDNSSettings(w http.ResponseWriter, r *http.Request) { + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + log.Error(err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } + + dnsSettings, err := h.accountManager.GetDNSSettings(account.Id, user.Id) + if err != nil { + util.WriteError(err, w) + return + } + + apiDNSSettings := &api.DNSSettings{ + DisabledManagementGroups: dnsSettings.DisabledManagementGroups, + } + + util.WriteJSONObject(w, apiDNSSettings) +} + +// UpdateDNSSettings handles update to DNS settings of an account +func (h *DNSSettings) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { + claims := h.jwtExtractor.ExtractClaimsFromRequestContext(r, h.authAudience) + account, user, err := h.accountManager.GetAccountFromToken(claims) + if err != nil { + util.WriteError(err, w) + return + } + + var req api.PutApiDnsSettingsJSONRequestBody + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + updateDNSSettings := &server.DNSSettings{ + DisabledManagementGroups: req.DisabledManagementGroups, + } + + err = h.accountManager.SaveDNSSettings(account.Id, user.Id, updateDNSSettings) + if err != nil { + util.WriteError(err, w) + return + } + + resp := api.DNSSettings{ + DisabledManagementGroups: updateDNSSettings.DisabledManagementGroups, + } + + util.WriteJSONObject(w, &resp) +} diff --git a/management/server/http/dns_settings_test.go b/management/server/http/dns_settings_test.go new file mode 100644 index 000000000..58bec62f1 --- /dev/null +++ b/management/server/http/dns_settings_test.go @@ -0,0 +1,149 @@ +package http + +import ( + "bytes" + "encoding/json" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/status" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/mock_server" +) + +const ( + testDNSSettingsAccountID = "test_id" + testDNSSettingsExistingGroup = "test_group" + testDNSSettingsUserID = "test_user" +) + +var baseExistingDNSSettings = &server.DNSSettings{ + DisabledManagementGroups: []string{testDNSSettingsExistingGroup}, +} + +var testingDNSSettingsAccount = &server.Account{ + Id: testDNSSettingsAccountID, + Domain: "hotmail.com", + Users: map[string]*server.User{ + testDNSSettingsUserID: server.NewAdminUser("test_user"), + }, + DNSSettings: baseExistingDNSSettings, +} + +func initDNSSettingsTestData() *DNSSettings { + return &DNSSettings{ + accountManager: &mock_server.MockAccountManager{ + GetDNSSettingsFunc: func(accountID string, userID string) (*server.DNSSettings, error) { + return testingDNSSettingsAccount.DNSSettings, nil + }, + SaveDNSSettingsFunc: func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { + if dnsSettingsToSave != nil { + return nil + } + return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") + }, + GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil + }, + }, + authAudience: "", + jwtExtractor: jwtclaims.ClaimsExtractor{ + ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { + return jwtclaims.AuthorizationClaims{ + UserId: "test_user", + Domain: "hotmail.com", + AccountId: testDNSSettingsAccountID, + } + }, + }, + } +} + +func TestDNSSettingsHandlers(t *testing.T) { + tt := []struct { + name string + expectedStatus int + expectedBody bool + expectedDNSSettings *api.DNSSettings + requestType string + requestPath string + requestBody io.Reader + }{ + { + name: "Get DNS Settings", + requestType: http.MethodGet, + requestPath: "/api/dns/settings", + expectedStatus: http.StatusOK, + expectedBody: true, + expectedDNSSettings: &api.DNSSettings{ + DisabledManagementGroups: baseExistingDNSSettings.DisabledManagementGroups, + }, + }, + { + name: "Update DNS Settings", + requestType: http.MethodPut, + requestPath: "/api/dns/settings", + requestBody: bytes.NewBuffer( + []byte("{\"disabled_management_groups\":[\"group1\",\"group2\"]}")), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedDNSSettings: &api.DNSSettings{ + DisabledManagementGroups: []string{"group1", "group2"}, + }, + }, + { + name: "Update DNS Settings Empty Body", + requestType: http.MethodPut, + requestPath: "/api/dns/settings", + requestBody: bytes.NewBuffer( + []byte("{}")), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedDNSSettings: &api.DNSSettings{}, + }, + } + + p := initDNSSettingsTestData() + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + + router := mux.NewRouter() + router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET") + router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("I don't know what I expected; %v", err) + } + + if status := recorder.Code; status != tc.expectedStatus { + t.Errorf("handler returned wrong status code: got %v want %v, content: %s", + status, tc.expectedStatus, string(content)) + return + } + + if !tc.expectedBody { + return + } + + got := &api.DNSSettings{} + if err = json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + assert.Equal(t, tc.expectedDNSSettings, got) + }) + } +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index d4beaec90..5015d5650 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -41,6 +41,7 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience routesHandler := NewRoutes(accountManager, authAudience) nameserversHandler := NewNameservers(accountManager, authAudience) eventsHandler := NewEvents(accountManager, authAudience) + dnsSettingsHandler := NewDNSSettings(accountManager, authAudience) apiHandler.HandleFunc("/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/peers/{id}", peersHandler.HandlePeer). @@ -84,6 +85,9 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience apiHandler.HandleFunc("/events", eventsHandler.GetEvents).Methods("GET", "OPTIONS") + apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS") + apiHandler.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS") + err = apiHandler.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { methods, err := route.GetMethods() if err != nil { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 2e8249573..7edefd1ca 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -64,6 +64,8 @@ type MockAccountManager struct { GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) GetDNSDomainFunc func() string GetEventsFunc func(accountID, userID string) ([]*activity.Event, error) + GetDNSSettingsFunc func(accountID string, userID string) (*server.DNSSettings, error) + SaveDNSSettingsFunc func(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -499,3 +501,19 @@ func (am *MockAccountManager) GetEvents(accountID, userID string) ([]*activity.E } return nil, status.Errorf(codes.Unimplemented, "method GetEvents is not implemented") } + +// GetDNSSettings mocks GetDNSSettings of the AccountManager interface +func (am *MockAccountManager) GetDNSSettings(accountID string, userID string) (*server.DNSSettings, error) { + if am.GetDNSSettingsFunc != nil { + return am.GetDNSSettingsFunc(accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetDNSSettings is not implemented") +} + +// SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface +func (am *MockAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { + if am.SaveDNSSettingsFunc != nil { + return am.SaveDNSSettingsFunc(accountID, userID, dnsSettingsToSave) + } + return status.Errorf(codes.Unimplemented, "method SaveDNSSettings is not implemented") +} diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index f0f8ea06c..a05ad49f7 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1128,19 +1128,20 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - defaultGroup, err := account.GetGroupAll() - if err != nil { - return nil, err + newGroup1 := &Group{ + ID: group1ID, + Name: group1ID, + } + + newGroup2 := &Group{ + ID: group2ID, + Name: group2ID, } - newGroup1 := defaultGroup.Copy() - newGroup1.ID = group1ID - newGroup2 := defaultGroup.Copy() - newGroup2.ID = group2ID account.Groups[newGroup1.ID] = newGroup1 account.Groups[newGroup2.ID] = newGroup2 - err = am.Store.SaveAccount(account) + err := am.Store.SaveAccount(account) if err != nil { return nil, err } diff --git a/management/server/peer.go b/management/server/peer.go index 1e4bbd4f2..b3408acd6 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -323,16 +323,19 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, aclPeers := account.getPeersByACL(peerPubKey) routesUpdate := account.getRoutesToSync(peerPubKey, aclPeers) - var zones []nbdns.CustomZone - peersCustomZone := getPeersCustomZone(account, am.dnsDomain) - if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) + dnsManagementStatus := account.getPeerDNSManagementStatus(peerPubKey) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, } - dnsUpdate := nbdns.Config{ - ServiceEnable: true, - CustomZones: zones, - NameServerGroups: getPeerNSGroups(account, peerPubKey), + if dnsManagementStatus { + var zones []nbdns.CustomZone + peersCustomZone := getPeersCustomZone(account, am.dnsDomain) + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(account, peerPubKey) } return &NetworkMap{