Compare commits

...

16 Commits

Author SHA1 Message Date
braginini
45de3c1b06 Indicate to IdP that invite was redeemed 2022-10-18 21:03:25 +02:00
Maycon Santos
04e4407ea7 Add anonymous usage metrics collection (#508)
This will help us understand usage on self-hosted deployments

The collection may be disabled by using the flag --disable-anonymous-metrics or 
NETBIRD_DISABLE_ANONYMOUS_METRICS in setup.env
2022-10-16 13:33:46 +02:00
Misha Bragin
06055af361 Super user invites (#483)
This PR brings user invites logic to the Management service
via HTTP API. 
The POST /users/ API endpoint creates a new user in the Idp
and then in the local storage. 
Once the invited user signs ups, the account invitation is redeemed.
There are a few limitations.
This works only with an enabled IdP manager.
Users that already have a registered account can't be invited.
2022-10-13 18:26:31 +02:00
Maycon Santos
abd1230a69 Disable uninstall message when upgrade is silent (#505)
Fix a problem with $INSTDIR pointing to subfolder
2022-10-13 15:00:39 +02:00
Maycon Santos
f7de12daf8 Support custom redirect URIs (#499) 2022-10-12 12:25:46 +02:00
Maycon Santos
c49fb0c40c Get windows version from registry (#502)
Avoid problems with localization by retrieving
version information from registry

Return a default 0.0.0.0 if operation fails
2022-10-12 12:25:33 +02:00
Maycon Santos
6e9a162877 Seticon only when status changes (#504)
* Seticon only when status changes

This prevents a memory leak with the systray lib
when setting the icon every 2 seconds causes a large memory consumption

see https://github.com/getlantern/systray/issues/135

* Use fork with permanent fix
2022-10-12 12:25:06 +02:00
Maycon Santos
b4e03f4616 Feature/add nameservers API endpoint (#491)
Add nameservers endpoint and Open API definition

updated open api generator cli
2022-10-10 11:06:54 +02:00
Misha Bragin
369a7ef345 Add SSO MFA demo gif (#489) 2022-10-10 11:06:25 +02:00
Maycon Santos
c88e6a7342 Run tests only on branch main and on pull requests (#492)
* Use reusable workflow and control push and pr test exec

* use format

* use path ref

* Run tests on push to main and on PRs
2022-10-03 00:17:16 +05:00
Maycon Santos
2cd9b11e7d Add DNS nameserver support to management (#484)
Add DNS package and Nameserver group objects

Add CRUD operations for Nameserver Groups to account manager

Add Routes and Nameservers to Account Copy method

Run docker tests with timeout and serial flags
2022-09-30 16:47:11 +05:00
Maycon Santos
93d20e370b Add incoming routing rules (#486)
add an income firewall rule for each routing pair
the pair for the income rule has inverted
source and destination
2022-09-30 14:39:15 +05:00
Maycon Santos
878ca6db22 Check if domain from claim is valid (#485)
If domain is invalid we call GetAccountByUserOrAccountId
2022-09-29 13:51:18 +05:00
braginini
2033650908 Remove IdP client secret validation 2022-09-26 18:58:14 +02:00
Misha Bragin
34c1c7d901 Add hostname, userID, ui version to the HTTP API peer response (#479) 2022-09-26 18:02:45 +02:00
Misha Bragin
051fd3a4d7 Fix Management and Signal gRPC client stream leak (#482) 2022-09-26 18:02:20 +02:00
63 changed files with 4298 additions and 523 deletions

View File

@@ -1,5 +1,10 @@
name: Test Code Darwin name: Test Code Darwin
on: [push,pull_request]
on:
push:
branches:
- main
pull_request:
jobs: jobs:
test: test:

View File

@@ -1,5 +1,10 @@
name: Test Code Linux name: Test Code Linux
on: [push,pull_request]
on:
push:
branches:
- main
pull_request:
jobs: jobs:
test: test:
@@ -75,13 +80,13 @@ jobs:
- run: chmod +x *testing.bin - run: chmod +x *testing.bin
- name: Run Iface tests in docker - name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker - name: Run Engine tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker - name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -1,5 +1,10 @@
name: Test Code Windows name: Test Code Windows
on: [push,pull_request]
on:
push:
branches:
- main
pull_request:
jobs: jobs:
pre: pre:

View File

@@ -1,5 +1,10 @@
name: Test Docker Compose Linux name: Test Docker Compose Linux
on: [push,pull_request]
on:
push:
branches:
- main
pull_request:
jobs: jobs:
test: test:
@@ -51,6 +56,7 @@ jobs:
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
run: | run: |
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
@@ -58,6 +64,8 @@ jobs:
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep USE_AUTH0 docker-compose.yml | grep $CI_NETBIRD_USE_AUTH0 grep USE_AUTH0 docker-compose.yml | grep $CI_NETBIRD_USE_AUTH0
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "http://localhost:33073" grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "http://localhost:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files working-directory: infrastructure_files

View File

@@ -41,7 +41,7 @@ builds:
- arm64 - arm64
- arm - arm
ldflags: ldflags:
- -s -w -X main.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/client/system.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-signal - id: netbird-signal

View File

@@ -63,8 +63,7 @@ NetBird creates an overlay peer-to-peer network connecting machines automaticall
### Secure peer-to-peer VPN with SSO and MFA in minutes ### Secure peer-to-peer VPN with SSO and MFA in minutes
<p float="left" align="middle"> <p float="left" align="middle">
<img src="docs/media/peerA.gif" width="400"/> <img src="docs/media/netbird-sso-mfa-demo.gif" width="800"/>
<img src="docs/media/peerB.gif" width="400"/>
</p> </p>
**Note**: The `main` branch may be in an *unstable or even broken state* during development. **Note**: The `main` branch may be in an *unstable or even broken state* during development.

View File

@@ -101,6 +101,7 @@ done:
Pop $2 Pop $2
Exch $1 Exch $1
FunctionEnd FunctionEnd
!macro GetAppFromCommand in out !macro GetAppFromCommand in out
Push "${in}" Push "${in}"
Call GetAppFromCommand Call GetAppFromCommand
@@ -117,7 +118,7 @@ Call GetAppFromCommand ; Remove quotes and parameters from UninstCommand
Pop $0 Pop $0
Pop $1 Pop $1
GetFullPathName $2 "$0\.." GetFullPathName $2 "$0\.."
ExecWait '"$0" $1 _?=$2' ExecWait '"$0" /S $1 _?=$2'
Delete "$0" ; Extra cleanup because we used _?= Delete "$0" ; Extra cleanup because we used _?=
RMDir "$2" RMDir "$2"
Pop $2 Pop $2
@@ -126,30 +127,27 @@ Pop $0
!macroend !macroend
Function .onInit Function .onInit
StrCpy $INSTDIR "${INSTALL_DIR}"
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\Wiretrustee" "UninstallString"
${If} $R0 != ""
MessageBox MB_YESNO|MB_ICONQUESTION "Wiretrustee is installed. We must remove it before installing Netbird. Procced?" IDNO noWTUninstOld
!insertmacro UninstallPreviousNSIS $R0 "/NoMsgBox"
noWTUninstOld:
${EndIf}
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString" ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
${If} $R0 != "" ${If} $R0 != ""
MessageBox MB_YESNO|MB_ICONQUESTION "$(^NAME) is already installed. Do you want to remove the previous version?" IDNO noUninstOld # if silent install jump to uninstall step
!insertmacro UninstallPreviousNSIS $R0 "/NoMsgBox" IfSilent uninstall
noUninstOld:
MessageBox MB_YESNO|MB_ICONQUESTION "NetBird is already installed. We must remove it before installing upgrading NetBird. Proceed?" IDNO done IDYES uninstall
uninstall:
!insertmacro UninstallPreviousNSIS $R0 "/NoMsgBox"
done:
${EndIf} ${EndIf}
FunctionEnd FunctionEnd
###################################################################### ######################################################################
Section -MainProgram Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
SetOverwrite ifnewer # SetOverwrite ifnewer
SetOutPath "$INSTDIR" SetOutPath "$INSTDIR"
File /r "..\\dist\\netbird_windows_amd64\\" File /r "..\\dist\\netbird_windows_amd64\\"
SectionEnd SectionEnd
###################################################################### ######################################################################
Section -Icons_Reg Section -Icons_Reg
@@ -172,24 +170,29 @@ SetShellVarContext current
CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}" CreateShortCut "$SMPROGRAMS\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}" CreateShortCut "$DESKTOP\${APP_NAME}.lnk" "$INSTDIR\${UI_APP_EXE}"
SetShellVarContext all SetShellVarContext all
SectionEnd
Section -Post
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service install'
Exec '"$INSTDIR\${MAIN_APP_EXE}" service start' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service start'
# sleep a bit for visibility # sleep a bit for visibility
Sleep 1000 Sleep 1000
SectionEnd SectionEnd
###################################################################### ######################################################################
Section Uninstall Section Uninstall
${INSTALL_TYPE} ${INSTALL_TYPE}
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
Exec '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client # kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe` ExecWait `taskkill /im ${UI_APP_EXE}.exe`
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
Sleep 3000 Sleep 3000
Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}"
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
SetShellVarContext current SetShellVarContext current

View File

@@ -292,9 +292,6 @@ func isProviderConfigValid(config ProviderConfig) error {
if config.ClientID == "" { if config.ClientID == "" {
return fmt.Errorf(errorMSGFormat, "Client ID") return fmt.Errorf(errorMSGFormat, "Client ID")
} }
if config.ClientSecret == "" {
return fmt.Errorf(errorMSGFormat, "Client Secret")
}
if config.TokenEndpoint == "" { if config.TokenEndpoint == "" {
return fmt.Errorf(errorMSGFormat, "Token Endpoint") return fmt.Errorf(errorMSGFormat, "Token Endpoint")
} }

View File

@@ -9,14 +9,16 @@ import (
import "github.com/google/nftables" import "github.com/google/nftables"
const ( const (
ipv6Forwarding = "netbird-rt-ipv6-forwarding" ipv6Forwarding = "netbird-rt-ipv6-forwarding"
ipv4Forwarding = "netbird-rt-ipv4-forwarding" ipv4Forwarding = "netbird-rt-ipv4-forwarding"
ipv6Nat = "netbird-rt-ipv6-nat" ipv6Nat = "netbird-rt-ipv6-nat"
ipv4Nat = "netbird-rt-ipv4-nat" ipv4Nat = "netbird-rt-ipv4-nat"
natFormat = "netbird-nat-%s" natFormat = "netbird-nat-%s"
forwardingFormat = "netbird-fwd-%s" forwardingFormat = "netbird-fwd-%s"
ipv6 = "ipv6" inNatFormat = "netbird-nat-in-%s"
ipv4 = "ipv4" inForwardingFormat = "netbird-fwd-in-%s"
ipv6 = "ipv6"
ipv4 = "ipv4"
) )
func genKey(format string, input string) string { func genKey(format string, input string) string {
@@ -53,3 +55,13 @@ func NewFirewall(parentCTX context.Context) firewallManager {
return manager return manager
} }
func getInPair(pair routerPair) routerPair {
return routerPair{
ID: pair.ID,
// invert source/destination
source: pair.destination,
destination: pair.source,
masquerade: pair.masquerade,
}
}

View File

@@ -311,7 +311,37 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
i.mux.Lock() i.mux.Lock()
defer i.mux.Unlock() defer i.mux.Unlock()
err := i.insertRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, routingFinalForwardJump, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.insertRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, pair)
if err != nil {
return err
}
err = i.insertRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, routingFinalNatJump, getInPair(pair))
if err != nil {
return err
}
return nil
}
// insertRoutingRule inserts an iptable rule
func (i *iptablesManager) insertRoutingRule(keyFormat, table, chain, jump string, pair routerPair) error {
var err error var err error
prefix := netip.MustParsePrefix(pair.source) prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4 ipVersion := ipv4
iptablesClient := i.ipv4Client iptablesClient := i.ipv4Client
@@ -320,43 +350,22 @@ func (i *iptablesManager) InsertRoutingRules(pair routerPair) error {
ipVersion = ipv6 ipVersion = ipv6
} }
forwardRuleKey := genKey(forwardingFormat, pair.ID) ruleKey := genKey(keyFormat, pair.ID)
forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, pair.source, pair.destination) rule := genRuleSpec(jump, ruleKey, pair.source, pair.destination)
existingRule, found := i.rules[ipVersion][forwardRuleKey] existingRule, found := i.rules[ipVersion][ruleKey]
if found { if found {
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...) err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err) return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
} }
delete(i.rules[ipVersion], forwardRuleKey) delete(i.rules[ipVersion], ruleKey)
} }
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) err = iptablesClient.Insert(table, chain, 1, rule...)
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while adding new forwarding rule for %s: %v", pair.destination, err) return fmt.Errorf("iptables: error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
} }
i.rules[ipVersion][forwardRuleKey] = forwardRule i.rules[ipVersion][ruleKey] = rule
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, pair.source, pair.destination)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rulefor %s: %v", pair.destination, err)
}
delete(i.rules[ipVersion], natRuleKey)
}
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
if err != nil {
return fmt.Errorf("iptables: error while adding new nat rulefor %s: %v", pair.destination, err)
}
i.rules[ipVersion][natRuleKey] = natRule
return nil return nil
} }
@@ -366,7 +375,37 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
i.mux.Lock() i.mux.Lock()
defer i.mux.Unlock() defer i.mux.Unlock()
err := i.removeRoutingRule(forwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inForwardingFormat, iptablesFilterTable, iptablesRoutingForwardingChain, getInPair(pair))
if err != nil {
return err
}
if !pair.masquerade {
return nil
}
err = i.removeRoutingRule(natFormat, iptablesNatTable, iptablesRoutingNatChain, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(inNatFormat, iptablesNatTable, iptablesRoutingNatChain, getInPair(pair))
if err != nil {
return err
}
return nil
}
// removeRoutingRule removes an iptables rule
func (i *iptablesManager) removeRoutingRule(keyFormat, table, chain string, pair routerPair) error {
var err error var err error
prefix := netip.MustParsePrefix(pair.source) prefix := netip.MustParsePrefix(pair.source)
ipVersion := ipv4 ipVersion := ipv4
iptablesClient := i.ipv4Client iptablesClient := i.ipv4Client
@@ -375,29 +414,23 @@ func (i *iptablesManager) RemoveRoutingRules(pair routerPair) error {
ipVersion = ipv6 ipVersion = ipv6
} }
forwardRuleKey := genKey(forwardingFormat, pair.ID) ruleKey := genKey(keyFormat, pair.ID)
existingRule, found := i.rules[ipVersion][forwardRuleKey] existingRule, found := i.rules[ipVersion][ruleKey]
if found { if found {
err = iptablesClient.DeleteIfExists(iptablesFilterTable, iptablesRoutingForwardingChain, existingRule...) err = iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil { if err != nil {
return fmt.Errorf("iptables: error while removing existing forwarding rule for %s: %v", pair.destination, err) return fmt.Errorf("iptables: error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.destination, err)
} }
} }
delete(i.rules[ipVersion], forwardRuleKey) delete(i.rules[ipVersion], ruleKey)
if !pair.masquerade {
return nil
}
natRuleKey := genKey(natFormat, pair.ID)
existingRule, found = i.rules[ipVersion][natRuleKey]
if found {
err = iptablesClient.DeleteIfExists(iptablesNatTable, iptablesRoutingNatChain, existingRule...)
if err != nil {
return fmt.Errorf("iptables: error while removing existing nat rule for %s: %v", pair.destination, err)
}
}
delete(i.rules[ipVersion], natRuleKey)
return nil return nil
} }
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == iptablesNatTable {
ruleType = "nat"
}
return ruleType
}

View File

@@ -159,6 +159,17 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.True(t, found, "forwarding rule should exist in the manager map") require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := genKey(natFormat, testCase.inputPair.ID) natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
@@ -172,7 +183,23 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
} else { } else {
require.False(t, exists, "nat rule should not be created") require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][natRuleKey] _, foundNat := manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, foundNat, "nat rule should exist in the map") require.False(t, foundNat, "nat rule should not exist in the map")
}
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
if testCase.inputPair.masquerade {
require.True(t, exists, "income nat rule should be created")
foundNatRule, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map")
} }
}) })
} }
@@ -213,12 +240,24 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...) err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesFilterTable, iptablesRoutingForwardingChain, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := genKey(natFormat, testCase.inputPair.ID) natRuleKey := genKey(natFormat, testCase.inputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination) natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.inputPair.source, testCase.inputPair.destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...) err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, getInPair(testCase.inputPair).source, getInPair(testCase.inputPair).destination)
err = iptablesClient.Insert(iptablesNatTable, iptablesRoutingNatChain, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
delete(manager.rules, ipv4) delete(manager.rules, ipv4)
delete(manager.rules, ipv6) delete(manager.rules, ipv6)
@@ -235,12 +274,26 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
_, found := manager.rules[testCase.ipVersion][forwardRuleKey] _, found := manager.rules[testCase.ipVersion][forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map") require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesFilterTable, iptablesRoutingForwardingChain, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesFilterTable, iptablesRoutingForwardingChain)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[testCase.ipVersion][inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...) exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, natRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain) require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "nat rule should not exist") require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][natRuleKey] _, found = manager.rules[testCase.ipVersion][natRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map") require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(iptablesNatTable, iptablesRoutingNatChain, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s %s table and %s chain", testCase.ipVersion, iptablesNatTable, iptablesRoutingNatChain)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[testCase.ipVersion][inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
}) })
} }

View File

@@ -12,7 +12,6 @@ import (
) )
import "github.com/google/nftables" import "github.com/google/nftables"
//
const ( const (
nftablesTable = "netbird-rt" nftablesTable = "netbird-rt"
nftablesRoutingForwardingChain = "netbird-rt-fwd" nftablesRoutingForwardingChain = "netbird-rt-fwd"
@@ -248,53 +247,77 @@ func (n *nftablesManager) InsertRoutingRules(pair routerPair) error {
n.mux.Lock() n.mux.Lock()
defer n.mux.Unlock() defer n.mux.Unlock()
err := n.refreshRulesMap()
if err != nil {
return err
}
err = n.insertRoutingRule(forwardingFormat, nftablesRoutingForwardingChain, pair, false)
if err != nil {
return err
}
err = n.insertRoutingRule(inForwardingFormat, nftablesRoutingForwardingChain, getInPair(pair), false)
if err != nil {
return err
}
if pair.masquerade {
err = n.insertRoutingRule(natFormat, nftablesRoutingNatChain, pair, true)
if err != nil {
return err
}
err = n.insertRoutingRule(inNatFormat, nftablesRoutingNatChain, getInPair(pair), true)
if err != nil {
return err
}
}
err = n.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (n *nftablesManager) insertRoutingRule(format, chain string, pair routerPair, isNat bool) error {
prefix := netip.MustParsePrefix(pair.source) prefix := netip.MustParsePrefix(pair.source)
sourceExp := generateCIDRMatcherExpressions("source", pair.source) sourceExp := generateCIDRMatcherExpressions("source", pair.source)
destExp := generateCIDRMatcherExpressions("destination", pair.destination) destExp := generateCIDRMatcherExpressions("destination", pair.destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) var expression []expr.Any
fwdKey := genKey(forwardingFormat, pair.ID) if isNat {
if prefix.Addr().Unmap().Is4() { expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv4,
Chain: n.chains[ipv4][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
} else { } else {
n.rules[fwdKey] = n.conn.InsertRule(&nftables.Rule{ expression = append(sourceExp, append(destExp, exprCounterAccept...)...)
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(fwdKey),
})
} }
if pair.masquerade { ruleKey := genKey(format, pair.ID)
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
natKey := genKey(natFormat, pair.ID)
if prefix.Addr().Unmap().Is4() { _, exists := n.rules[ruleKey]
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{ if exists {
Table: n.tableIPv4, err := n.removeRoutingRule(format, pair)
Chain: n.chains[ipv4][nftablesRoutingNatChain], if err != nil {
Exprs: natExp, return err
UserData: []byte(natKey),
})
} else {
n.rules[natKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(natKey),
})
} }
} }
err := n.conn.Flush() if prefix.Addr().Unmap().Is4() {
if err != nil { n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.destination, err) Table: n.tableIPv4,
Chain: n.chains[ipv4][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
} else {
n.rules[ruleKey] = n.conn.InsertRule(&nftables.Rule{
Table: n.tableIPv6,
Chain: n.chains[ipv6][chain],
Exprs: expression,
UserData: []byte(ruleKey),
})
} }
return nil return nil
} }
@@ -309,26 +332,26 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return err return err
} }
fwdKey := genKey(forwardingFormat, pair.ID) err = n.removeRoutingRule(forwardingFormat, pair)
natKey := genKey(natFormat, pair.ID) if err != nil {
fwdRule, found := n.rules[fwdKey] return err
if found {
err = n.conn.DelRule(fwdRule)
if err != nil {
return fmt.Errorf("nftables: unable to remove forwarding rule for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removing forwarding rule for %s", pair.destination)
delete(n.rules, fwdKey)
} }
natRule, found := n.rules[natKey]
if found { err = n.removeRoutingRule(inForwardingFormat, getInPair(pair))
err = n.conn.DelRule(natRule) if err != nil {
if err != nil { return err
return fmt.Errorf("nftables: unable to remove nat rule for %s: %v", pair.destination, err)
}
log.Debugf("nftables: removing nat rule for %s", pair.destination)
delete(n.rules, natKey)
} }
err = n.removeRoutingRule(natFormat, pair)
if err != nil {
return err
}
err = n.removeRoutingRule(inNatFormat, getInPair(pair))
if err != nil {
return err
}
err = n.conn.Flush() err = n.conn.Flush()
if err != nil { if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err) return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.destination, err)
@@ -337,6 +360,29 @@ func (n *nftablesManager) RemoveRoutingRules(pair routerPair) error {
return nil return nil
} }
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (n *nftablesManager) removeRoutingRule(format string, pair routerPair) error {
ruleKey := genKey(format, pair.ID)
rule, found := n.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := n.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.destination)
delete(n.rules, ruleKey)
}
return nil
}
// getPayloadDirectives get expression directives based on ip version and direction // getPayloadDirectives get expression directives based on ip version and direction
func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) { func getPayloadDirectives(direction string, isIPv4 bool, isIPv6 bool) (uint32, uint32, []byte) {
switch { switch {

View File

@@ -189,6 +189,45 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
} }
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
testingExpression = append(sourceExp, destExp...)
inFwdRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
found = 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.inputPair.masquerade {
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
found := 0
for _, registeredChains := range manager.chains {
for _, chain := range registeredChains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
}) })
} }
} }
@@ -241,6 +280,28 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
UserData: []byte(natRuleKey), UserData: []byte(natRuleKey),
}) })
sourceExp = generateCIDRMatcherExpressions("source", getInPair(testCase.inputPair).source)
destExp = generateCIDRMatcherExpressions("destination", getInPair(testCase.inputPair).destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...)
inForwardRuleKey := genKey(inForwardingFormat, testCase.inputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingForwardingChain],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...)
inNatRuleKey := genKey(inNatFormat, testCase.inputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: table,
Chain: manager.chains[testCase.ipVersion][nftablesRoutingNatChain],
Exprs: natExp,
UserData: []byte(inNatRuleKey),
})
err = nftablesTestingClient.Flush() err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
@@ -259,8 +320,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should exist") require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should exist") require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
} }
} }
} }

View File

@@ -1,36 +1,17 @@
package system package system
import ( import (
"bytes"
"context" "context"
"fmt"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
"os" "os"
"os/exec"
"runtime" "runtime"
"strings"
) )
// GetInfo retrieves and parses the system information // GetInfo retrieves and parses the system information
func GetInfo(ctx context.Context) *Info { func GetInfo(ctx context.Context) *Info {
cmd := exec.Command("cmd", "ver") ver := getOSVersion()
cmd.Stdin = strings.NewReader("some")
var out bytes.Buffer
var stderr bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
panic(err)
}
osStr := strings.Replace(out.String(), "\n", "", -1)
osStr = strings.Replace(osStr, "\r\n", "", -1)
tmp1 := strings.Index(osStr, "[Version")
tmp2 := strings.Index(osStr, "]")
var ver string
if tmp1 == -1 || tmp2 == -1 {
ver = "unknown"
} else {
ver = osStr[tmp1+9 : tmp2]
}
gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()} gio := &Info{Kernel: "windows", OSVersion: ver, Core: ver, Platform: "unknown", OS: "windows", GoOS: runtime.GOOS, CPUs: runtime.NumCPU()}
gio.Hostname, _ = os.Hostname() gio.Hostname, _ = os.Hostname()
gio.WiretrusteeVersion = NetbirdVersion() gio.WiretrusteeVersion = NetbirdVersion()
@@ -38,3 +19,37 @@ func GetInfo(ctx context.Context) *Info {
return gio return gio
} }
func getOSVersion() string {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\Microsoft\Windows NT\CurrentVersion`, registry.QUERY_VALUE)
if err != nil {
log.Error(err)
return "0.0.0.0"
}
defer func() {
deferErr := k.Close()
if deferErr != nil {
log.Error(deferErr)
}
}()
major, _, err := k.GetIntegerValue("CurrentMajorVersionNumber")
if err != nil {
log.Error(err)
}
minor, _, err := k.GetIntegerValue("CurrentMinorVersionNumber")
if err != nil {
log.Error(err)
}
build, _, err := k.GetStringValue("CurrentBuildNumber")
if err != nil {
log.Error(err)
}
// Update Build Revision
ubr, _, err := k.GetIntegerValue("UBR")
if err != nil {
log.Error(err)
}
ver := fmt.Sprintf("%d.%d.%s.%d", major, minor, build, ubr)
return ver
}

View File

@@ -324,12 +324,12 @@ func (s *serviceClient) updateStatus() error {
return err return err
} }
if status.Status == string(internal.StatusConnected) { if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() {
systray.SetIcon(s.icConnected) systray.SetIcon(s.icConnected)
s.mStatus.SetTitle("Connected") s.mStatus.SetTitle("Connected")
s.mUp.Disable() s.mUp.Disable()
s.mDown.Enable() s.mDown.Enable()
} else { } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
systray.SetIcon(s.icDisconnected) systray.SetIcon(s.icDisconnected)
s.mStatus.SetTitle("Disconnected") s.mStatus.SetTitle("Disconnected")
s.mDown.Disable() s.mDown.Disable()

6
dns/dns.go Normal file
View File

@@ -0,0 +1,6 @@
// Package dns implement dns types and standard methods and functions
// to parse and normalize dns records and configuration
package dns
// DefaultDNSPort well-known port number
const DefaultDNSPort = 53

184
dns/nameserver.go Normal file
View File

@@ -0,0 +1,184 @@
package dns
import (
"fmt"
"net/netip"
"net/url"
"strconv"
"strings"
)
const (
// MaxGroupNameChar maximum group name size
MaxGroupNameChar = 40
// InvalidNameServerType invalid nameserver type
InvalidNameServerType NameServerType = iota
// UDPNameServerType udp nameserver type
UDPNameServerType
)
const (
// InvalidNameServerTypeString invalid nameserver type as string
InvalidNameServerTypeString = "invalid"
// UDPNameServerTypeString udp nameserver type as string
UDPNameServerTypeString = "udp"
)
// NameServerType nameserver type
type NameServerType int
// String returns nameserver type string
func (n NameServerType) String() string {
switch n {
case UDPNameServerType:
return UDPNameServerTypeString
default:
return InvalidNameServerTypeString
}
}
// ToNameServerType returns a nameserver type
func ToNameServerType(typeString string) NameServerType {
switch typeString {
case UDPNameServerTypeString:
return UDPNameServerType
default:
return InvalidNameServerType
}
}
// NameServerGroup group of nameservers and with group ids
type NameServerGroup struct {
// ID identifier of group
ID string
// Name group name
Name string
// Description group description
Description string
// NameServers list of nameservers
NameServers []NameServer
// Groups list of peer group IDs to distribute the nameservers information
Groups []string
// Enabled group status
Enabled bool
}
// NameServer represents a DNS nameserver
type NameServer struct {
// IP address of nameserver
IP netip.Addr
// NSType nameserver type
NSType NameServerType
// Port nameserver listening port
Port int
}
// Copy copies a nameserver object
func (n *NameServer) Copy() *NameServer {
return &NameServer{
IP: n.IP,
NSType: n.NSType,
Port: n.Port,
}
}
// IsEqual compares one nameserver with the other
func (n *NameServer) IsEqual(other *NameServer) bool {
return other.IP == n.IP &&
other.NSType == n.NSType &&
other.Port == n.Port
}
// ParseNameServerURL parses a nameserver url in the format <type>://<ip>:<port>, e.g., udp://1.1.1.1:53
func ParseNameServerURL(nsURL string) (NameServer, error) {
parsedURL, err := url.Parse(nsURL)
if err != nil {
return NameServer{}, err
}
var ns NameServer
parsedScheme := strings.ToLower(parsedURL.Scheme)
nsType := ToNameServerType(parsedScheme)
if nsType == InvalidNameServerType {
return NameServer{}, fmt.Errorf("invalid nameserver url schema type, got %s", parsedScheme)
}
ns.NSType = nsType
parsedPort, err := strconv.Atoi(parsedURL.Port())
if err != nil {
return NameServer{}, fmt.Errorf("invalid nameserver url port, got %s", parsedURL.Port())
}
ns.Port = parsedPort
parsedAddr, err := netip.ParseAddr(parsedURL.Hostname())
if err != nil {
return NameServer{}, fmt.Errorf("invalid nameserver url IP, got %s", parsedURL.Hostname())
}
ns.IP = parsedAddr
return ns, nil
}
// Copy copies a nameserver group object
func (g *NameServerGroup) Copy() *NameServerGroup {
return &NameServerGroup{
ID: g.ID,
Name: g.Name,
Description: g.Description,
NameServers: g.NameServers,
Groups: g.Groups,
Enabled: g.Enabled,
}
}
// IsEqual compares one nameserver group with the other
func (g *NameServerGroup) IsEqual(other *NameServerGroup) bool {
return other.ID == g.ID &&
other.Name == g.Name &&
other.Description == g.Description &&
compareNameServerList(g.NameServers, other.NameServers) &&
compareGroupsList(g.Groups, other.Groups)
}
func compareNameServerList(list, other []NameServer) bool {
if len(list) != len(other) {
return false
}
for _, ns := range list {
if !containsNameServer(ns, other) {
return false
}
}
return true
}
func containsNameServer(element NameServer, list []NameServer) bool {
for _, ns := range list {
if ns.IsEqual(&element) {
return true
}
}
return false
}
func compareGroupsList(list, other []string) bool {
if len(list) != len(other) {
return false
}
for _, id := range list {
match := false
for _, otherID := range other {
if id == otherID {
match = true
break
}
}
if !match {
return false
}
}
return true
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 887 KiB

9
go.mod
View File

@@ -32,7 +32,7 @@ require (
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/coreos/go-iptables v0.6.0 github.com/coreos/go-iptables v0.6.0
github.com/creack/pty v1.1.18 github.com/creack/pty v1.1.18
github.com/eko/gocache/v2 v2.3.1 github.com/eko/gocache/v3 v3.1.1
github.com/getlantern/systray v1.2.1 github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4 github.com/gliderlabs/ssh v0.3.4
github.com/google/nftables v0.0.0-20220808154552-2eca00135732 github.com/google/nftables v0.0.0-20220808154552-2eca00135732
@@ -41,7 +41,7 @@ require (
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.7.1 github.com/stretchr/testify v1.8.0
golang.org/x/net v0.0.0-20220630215102-69896b714898 golang.org/x/net v0.0.0-20220630215102-69896b714898
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
) )
@@ -99,6 +99,7 @@ require (
github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect
github.com/yuin/goldmark v1.4.1 // indirect github.com/yuin/goldmark v1.4.1 // indirect
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf // indirect
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect golang.org/x/image v0.0.0-20200430140353-33d19683fad8 // indirect
golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect
@@ -112,9 +113,11 @@ require (
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
honnef.co/go/tools v0.2.2 // indirect honnef.co/go/tools v0.2.2 // indirect
k8s.io/apimachinery v0.23.5 // indirect k8s.io/apimachinery v0.23.5 // indirect
) )
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84 replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c

17
go.sum
View File

@@ -134,8 +134,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cu
github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM=
github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/eko/gocache/v2 v2.3.1 h1:8MMkfqGJ0KIA9OXT0rXevcEIrU16oghrGDiIDJDFCa0= github.com/eko/gocache/v3 v3.1.1 h1:r3CBwLnqPkcK56h9Do2CWw1kZ4TeKK0wDE1Oo/YZnhs=
github.com/eko/gocache/v2 v2.3.1/go.mod h1:l2z8OmpZHL0CpuzDJtxm267eF3mZW1NqUsMj+sKrbUs= github.com/eko/gocache/v3 v3.1.1/go.mod h1:UpP/LyHAioP/a/dizgl0MpgZ3A3CkS4NbG/mWkGTQ9M=
github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc=
github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs=
@@ -178,8 +178,6 @@ github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55 h1:XYzSdCbkzOC0F
github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA= github.com/getlantern/hidden v0.0.0-20190325191715-f02dbb02be55/go.mod h1:6mmzY2kW1TOOrVy+r41Za2MxXM+hhqTtY3oBKd2AgFA=
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f h1:wrYrQttPS8FHIRSlsrcuKazukx/xqO/PpLZzZXsF+EA= github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f h1:wrYrQttPS8FHIRSlsrcuKazukx/xqO/PpLZzZXsF+EA=
github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA= github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA6pxftxoqzibIBBrLSUli+kYnJqrgBf9cIA=
github.com/getlantern/systray v1.2.1 h1:udsC2k98v2hN359VTFShuQW6GGprRprw6kD6539JikI=
github.com/getlantern/systray v1.2.1/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
@@ -474,6 +472,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRW
github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw=
github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84 h1:u8kpzR9ld1uAeH/BAXsS0SfcnhooLWeO7UgHSBVPD9I= github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84 h1:u8kpzR9ld1uAeH/BAXsS0SfcnhooLWeO7UgHSBVPD9I=
github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20220905002524-6ac14ad5ea84/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c h1:wK/s4nyZj/GF/kFJQjX6nqNfE0G3gcqd6hhnPCyp4sw=
github.com/netbirdio/systray v0.0.0-20221012095658-dc8eda872c0c/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
@@ -609,6 +609,7 @@ github.com/srwiley/rasterx v0.0.0-20200120212402-85cb7272f5e9/go.mod h1:mvWM0+15
github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
@@ -616,8 +617,9 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@@ -676,6 +678,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf h1:oXVg4h2qJDd9htKxb5SCpFBHLipW6hXmL3qpUixS2jw=
golang.org/x/exp v0.0.0-20220518171630-0b5c67f07fdf/go.mod h1:yh0Ynu2b5ZUe3MQfp2nM0ecK7wsgouWTDN0FNeJuIys=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 h1:6WW6V3x1P/jokJBpRQYUJnMHRP6isStQwCozxnU7XQw= golang.org/x/image v0.0.0-20200430140353-33d19683fad8 h1:6WW6V3x1P/jokJBpRQYUJnMHRP6isStQwCozxnU7XQw=
@@ -1190,8 +1194,9 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@@ -29,6 +29,8 @@ LETSENCRYPT_VOLUMESUFFIX="letsencrypt"
NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none" NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
NETBIRD_DISABLE_ANONYMOUS_METRICS=${NETBIRD_DISABLE_ANONYMOUS_METRICS:-false}
# exports # exports
export NETBIRD_DOMAIN export NETBIRD_DOMAIN
export NETBIRD_AUTH_CLIENT_ID export NETBIRD_AUTH_CLIENT_ID
@@ -45,6 +47,8 @@ export NETBIRD_MGMT_API_CERT_KEY_FILE
export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER export NETBIRD_AUTH_DEVICE_AUTH_PROVIDER
export NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID export NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID
export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT export NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT
export NETBIRD_AUTH_REDIRECT_URI
export NETBIRD_AUTH_SILENT_REDIRECT_URI
export TURN_USER export TURN_USER
export TURN_PASSWORD export TURN_PASSWORD
export TURN_MIN_PORT export TURN_MIN_PORT
@@ -53,3 +57,4 @@ export VOLUME_PREFIX
export MGMT_VOLUMESUFFIX export MGMT_VOLUMESUFFIX
export SIGNAL_VOLUMESUFFIX export SIGNAL_VOLUMESUFFIX
export LETSENCRYPT_VOLUMESUFFIX export LETSENCRYPT_VOLUMESUFFIX
export NETBIRD_DISABLE_ANONYMOUS_METRICS

View File

@@ -18,6 +18,8 @@ services:
- NGINX_SSL_PORT=443 - NGINX_SSL_PORT=443
- LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN - LETSENCRYPT_DOMAIN=$NETBIRD_DOMAIN
- LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL - LETSENCRYPT_EMAIL=$NETBIRD_LETSENCRYPT_EMAIL
- AUTH_REDIRECT_URI=$NETBIRD_AUTH_REDIRECT_URI
- AUTH_SILENT_REDIRECT_URI=$NETBIRD_AUTH_SILENT_REDIRECT_URI
volumes: volumes:
- $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/ - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt/
# Signal # Signal
@@ -46,7 +48,7 @@ services:
# # port and command for Let's Encrypt validation without dashboard container # # port and command for Let's Encrypt validation without dashboard container
# - 443:443 # - 443:443
# command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"] # command: ["--letsencrypt-domain", "$NETBIRD_DOMAIN", "--log-file", "console"]
command: ["--port", "443", "--log-file", "console"] command: ["--port", "443", "--log-file", "console", "--disable-anonymous-metrics=$NETBIRD_DISABLE_ANONYMOUS_METRICS"]
# Coturn # Coturn
coturn: coturn:
image: coturn/coturn image: coturn/coturn

View File

@@ -13,3 +13,10 @@ NETBIRD_AUTH_DEVICE_AUTH_PROVIDER="none"
NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID="" NETBIRD_AUTH_DEVICE_AUTH_CLIENT_ID=""
# e.g. hello@mydomain.com # e.g. hello@mydomain.com
NETBIRD_LETSENCRYPT_EMAIL="" NETBIRD_LETSENCRYPT_EMAIL=""
# if your IDP provider doesn't support fragmented URIs, configure custom
# redirect and silent redirect URIs, these will be concatenated into your NETBIRD_DOMAIN domain.
# NETBIRD_AUTH_REDIRECT_URI="/peers"
# NETBIRD_AUTH_SILENT_REDIRECT_URI="/add-peers"
# Disable anonymous metrics collection, see more information at https://netbird.io/docs/FAQ/metrics-collection
NETBIRD_DISABLE_ANONYMOUS_METRICS=false

View File

@@ -11,3 +11,4 @@ NETBIRD_USE_AUTH0=$CI_NETBIRD_USE_AUTH0
NETBIRD_AUTH_AUDIENCE=$CI_NETBIRD_AUTH_AUDIENCE NETBIRD_AUTH_AUDIENCE=$CI_NETBIRD_AUTH_AUDIENCE
# e.g. hello@mydomain.com # e.g. hello@mydomain.com
NETBIRD_LETSENCRYPT_EMAIL="" NETBIRD_LETSENCRYPT_EMAIL=""
NETBIRD_AUTH_REDIRECT_URI="/peers"

View File

@@ -109,7 +109,9 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return err return err
} }
stream, err := c.connectToStream(*serverPubKey) ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connectToStream(ctx, *serverPubKey)
if err != nil { if err != nil {
log.Debugf("failed to open Management Service stream: %s", err) log.Debugf("failed to open Management Service stream: %s", err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied { if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
@@ -145,7 +147,7 @@ func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error
return nil return nil
} }
func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { func (c *GrpcClient) connectToStream(ctx context.Context, serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) {
req := &proto.SyncRequest{} req := &proto.SyncRequest{}
myPrivateKey := c.key myPrivateKey := c.key
@@ -156,9 +158,12 @@ func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.Management
log.Errorf("failed encrypting message: %s", err) log.Errorf("failed encrypting message: %s", err)
return nil, err return nil, err
} }
syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq} syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq}
return c.realClient.Sync(c.ctx, syncReq) sync, err := c.realClient.Sync(ctx, syncReq)
if err != nil {
return nil, err
}
return sync, nil
} }
func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error { func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error {

View File

@@ -1,12 +1,15 @@
package cmd package cmd
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"github.com/google/uuid"
httpapi "github.com/netbirdio/netbird/management/server/http" httpapi "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/metrics"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
@@ -161,6 +164,21 @@ var (
} }
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
installationID, err := getInstallationID(store)
if err != nil {
log.Errorf("cannot load TLS credentials: %v", err)
return err
}
fmt.Println("metrics ", disableMetrics)
if !disableMetrics {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
metricsWorker := metrics.NewWorker(ctx, installationID, store, peersUpdateManager)
go metricsWorker.Run()
}
var compatListener net.Listener var compatListener net.Listener
if mgmtPort != ManagementLegacyPort { if mgmtPort != ManagementLegacyPort {
// The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it // The Management gRPC server was running on port 33073 previously. Old agents that are already connected to it
@@ -228,6 +246,20 @@ func notifyStop(msg string) {
} }
} }
func getInstallationID(store server.Store) (string, error) {
installationID := store.GetInstallationID()
if installationID != "" {
return installationID, nil
}
installationID = strings.ToUpper(uuid.New().String())
err := store.SaveInstallationID(installationID)
if err != nil {
return "", err
}
return installationID, nil
}
func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) { func serveGRPC(grpcServer *grpc.Server, port int) (net.Listener, error) {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil { if err != nil {

View File

@@ -27,6 +27,7 @@ var (
mgmtConfig string mgmtConfig string
logLevel string logLevel string
logFile string logFile string
disableMetrics bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird-mgmt", Use: "netbird-mgmt",
@@ -66,6 +67,7 @@ func init() {
mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") mgmtCmd.Flags().StringVar(&certKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
mgmtCmd.Flags().BoolVar(&disableMetrics, "disable-anonymous-metrics", false, "disables push of anonymous usage metrics to NetBird")
rootCmd.MarkFlagRequired("config") //nolint rootCmd.MarkFlagRequired("config") //nolint
rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "") rootCmd.PersistentFlags().StringVar(&logLevel, "log-level", "info", "")

View File

@@ -1,4 +1,17 @@
#!/bin/bash #!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26 go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.26
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
protoc -I proto/ proto/management.proto --go_out=. --go-grpc_out=. protoc -I ./ ./management.proto --go_out=../ --go-grpc_out=../
cd "$old_pwd"

View File

@@ -3,8 +3,9 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/eko/gocache/v2/cache" "github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v2/store" cacheStore "github.com/eko/gocache/v3/store"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -15,6 +16,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"math/rand" "math/rand"
"reflect" "reflect"
"regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -28,6 +30,11 @@ const (
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
) )
func cacheEntryExpiration() time.Duration {
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
return time.Duration(r) * time.Millisecond
}
type AccountManager interface { type AccountManager interface {
GetOrCreateAccountByUser(userId, domain string) (*Account, error) GetOrCreateAccountByUser(userId, domain string) (*Account, error)
GetAccountByUser(userId string) (*Account, error) GetAccountByUser(userId string) (*Account, error)
@@ -39,11 +46,13 @@ type AccountManager interface {
autoGroups []string, autoGroups []string,
) (*SetupKey, error) ) (*SetupKey, error)
SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error)
CreateUser(accountID string, key *UserInfo) (*UserInfo, error)
ListSetupKeys(accountID string) ([]*SetupKey, error)
SaveUser(accountID string, key *User) (*UserInfo, error) SaveUser(accountID string, key *User) (*UserInfo, error)
GetSetupKey(accountID, keyID string) (*SetupKey, error) GetSetupKey(accountID, keyID string) (*SetupKey, error)
GetAccountById(accountId string) (*Account, error) GetAccountById(accountId string) (*Account, error)
GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error)
GetAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error)
IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error)
AccountExists(accountId string) (*bool, error) AccountExists(accountId string) (*bool, error)
GetPeer(peerKey string) (*Peer, error) GetPeer(peerKey string) (*Peer, error)
@@ -77,16 +86,25 @@ type AccountManager interface {
UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error) UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
DeleteRoute(accountID, routeID string) error DeleteRoute(accountID, routeID string) error
ListRoutes(accountID string) ([]*route.Route, error) ListRoutes(accountID string) ([]*route.Route, error)
ListSetupKeys(accountID string) ([]*SetupKey, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroup(accountID, nsGroupID string) error
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
} }
type DefaultAccountManager struct { type DefaultAccountManager struct {
Store Store Store Store
// mutex to synchronise account operations (e.g. generating Peer IP address inside the Network) // mux to synchronise account operations (e.g. generating Peer IP address inside the Network)
mux sync.Mutex mux sync.Mutex
// cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID
cacheMux sync.Mutex
// cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded
cacheLoading map[string]chan struct{}
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
idpManager idp.Manager idpManager idp.Manager
cacheManager cache.CacheInterface cacheManager cache.CacheInterface[[]*idp.UserData]
ctx context.Context ctx context.Context
} }
@@ -105,6 +123,7 @@ type Account struct {
Groups map[string]*Group Groups map[string]*Group
Rules map[string]*Rule Rules map[string]*Rule
Routes map[string]*route.Route Routes map[string]*route.Route
NameServerGroups map[string]*nbdns.NameServerGroup
} }
type UserInfo struct { type UserInfo struct {
@@ -113,6 +132,7 @@ type UserInfo struct {
Name string `json:"name"` Name string `json:"name"`
Role string `json:"role"` Role string `json:"role"`
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
Status string `json:"-"`
} }
func (a *Account) Copy() *Account { func (a *Account) Copy() *Account {
@@ -141,15 +161,27 @@ func (a *Account) Copy() *Account {
rules[id] = rule.Copy() rules[id] = rule.Copy()
} }
routes := map[string]*route.Route{}
for id, route := range a.Routes {
routes[id] = route.Copy()
}
nsGroups := map[string]*nbdns.NameServerGroup{}
for id, nsGroup := range a.NameServerGroups {
nsGroups[id] = nsGroup.Copy()
}
return &Account{ return &Account{
Id: a.Id, Id: a.Id,
CreatedBy: a.CreatedBy, CreatedBy: a.CreatedBy,
SetupKeys: setupKeys, SetupKeys: setupKeys,
Network: a.Network.Copy(), Network: a.Network.Copy(),
Peers: peers, Peers: peers,
Users: users, Users: users,
Groups: groups, Groups: groups,
Rules: rules, Rules: rules,
Routes: routes,
NameServerGroups: nsGroups,
} }
} }
@@ -172,6 +204,8 @@ func BuildManager(
peersUpdateManager: peersUpdateManager, peersUpdateManager: peersUpdateManager,
idpManager: idpManager, idpManager: idpManager,
ctx: context.Background(), ctx: context.Background(),
cacheMux: sync.Mutex{},
cacheLoading: map[string]chan struct{}{},
} }
// if account has not default group // if account has not default group
@@ -188,9 +222,9 @@ func BuildManager(
} }
gocacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) gocacheClient := gocache.New(CacheExpirationMax, 30*time.Minute)
gocacheStore := cacheStore.NewGoCache(gocacheClient, nil) gocacheStore := cacheStore.NewGoCache(gocacheClient)
am.cacheManager = cache.NewLoadable(am.loadFromCache, cache.New(gocacheStore)) am.cacheManager = cache.NewLoadable[[]*idp.UserData](am.loadAccount, cache.New[[]*idp.UserData](gocacheStore))
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
go func() { go func() {
@@ -235,11 +269,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error {
} }
for accountID, users := range userData { for accountID, users := range userData {
rand.Seed(time.Now().UnixNano()) err = am.cacheManager.Set(am.ctx, accountID, users, cacheStore.WithExpiration(cacheEntryExpiration()))
r := rand.Intn(int(CacheExpirationMax.Milliseconds()-CacheExpirationMin.Milliseconds())) + int(CacheExpirationMin.Milliseconds())
expiration := time.Duration(r) * time.Millisecond
err = am.cacheManager.Set(am.ctx, accountID, users, &cacheStore.Options{Expiration: expiration})
if err != nil { if err != nil {
return err return err
} }
@@ -273,7 +303,7 @@ func (am *DefaultAccountManager) GetAccountByUserOrAccountId(
if err != nil { if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId)
} }
err = am.updateIDPMetadata(userId, account.Id) err = am.addAccountIDToIDPAppMeta(userId, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -287,10 +317,28 @@ func isNil(i idp.Manager) bool {
return i == nil || reflect.ValueOf(i).IsNil() return i == nil || reflect.ValueOf(i).IsNil()
} }
// updateIDPMetadata update user's app metadata in idp manager // addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) error { func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(userID string, account *Account) error {
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
err := am.idpManager.UpdateUserAppMetadata(userId, idp.AppMetadata{WTAccountId: accountID})
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(userID, account)
if err != nil {
return err
}
if user != nil && user.AppMetadata.WTAccountID == account.Id {
// it was already set, so we skip the unnecessary update
log.Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
account.Id, userID)
return nil
}
err = am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id})
if err != nil {
return err
}
if err != nil { if err != nil {
return status.Errorf( return status.Errorf(
codes.Internal, codes.Internal,
@@ -298,45 +346,113 @@ func (am *DefaultAccountManager) updateIDPMetadata(userId, accountID string) err
err, err,
) )
} }
// refresh cache to reflect the update
_, err = am.refreshCache(account.Id)
if err != nil {
return err
}
} }
return nil return nil
} }
func (am *DefaultAccountManager) loadFromCache(_ context.Context, accountID interface{}) (interface{}, error) { func (am *DefaultAccountManager) loadAccount(_ context.Context, accountID interface{}) ([]*idp.UserData, error) {
log.Debugf("account %s not found in cache, reloading", accountID)
return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID)) return am.idpManager.GetAccount(fmt.Sprintf("%v", accountID))
} }
func (am *DefaultAccountManager) lookupUserInCache(user *User, accountID string) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCacheByEmail(email string, accountID string) (*idp.UserData, error) {
userData, err := am.lookupCache(map[string]*User{user.Id: user}, accountID) data, err := am.getAccountFromCache(accountID, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, datum := range userData { for _, datum := range data {
if datum.ID == user.Id { if datum.Email == email {
return datum, nil return datum, nil
} }
} }
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", user.Id) return nil, nil
} }
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, accountID string) ([]*idp.UserData, error) { // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
data, err := am.cacheManager.Get(am.ctx, accountID) func (am *DefaultAccountManager) lookupUserInCache(userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]struct{}, len(account.Users))
for _, user := range account.Users {
users[user.Id] = struct{}{}
}
log.Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(users, account.Id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
userData := data.([]*idp.UserData) for _, datum := range userData {
if datum.ID == userID {
return datum, nil
}
}
return nil, nil
}
func (am *DefaultAccountManager) refreshCache(accountID string) ([]*idp.UserData, error) {
return am.getAccountFromCache(accountID, true)
}
// getAccountFromCache returns user data for a given account ensuring that cache load happens only once
func (am *DefaultAccountManager) getAccountFromCache(accountID string, forceReload bool) ([]*idp.UserData, error) {
am.cacheMux.Lock()
loadingChan := am.cacheLoading[accountID]
if loadingChan == nil {
loadingChan = make(chan struct{})
am.cacheLoading[accountID] = loadingChan
am.cacheMux.Unlock()
defer func() {
am.cacheMux.Lock()
delete(am.cacheLoading, accountID)
close(loadingChan)
am.cacheMux.Unlock()
}()
if forceReload {
err := am.cacheManager.Delete(am.ctx, accountID)
if err != nil {
return nil, err
}
}
return am.cacheManager.Get(am.ctx, accountID)
}
am.cacheMux.Unlock()
log.Debugf("one request to get account %s is already running", accountID)
select {
case <-loadingChan:
// channel has been closed meaning cache was loaded => simply return from cache
return am.cacheManager.Get(am.ctx, accountID)
case <-time.After(5 * time.Second):
return nil, fmt.Errorf("timeout while waiting for account %s cache to reload", accountID)
}
}
func (am *DefaultAccountManager) lookupCache(accountUsers map[string]struct{}, accountID string) ([]*idp.UserData, error) {
data, err := am.getAccountFromCache(accountID, false)
if err != nil {
return nil, err
}
userDataMap := make(map[string]struct{}) userDataMap := make(map[string]struct{})
for _, datum := range userData { for _, datum := range data {
userDataMap[datum.ID] = struct{}{} userDataMap[datum.ID] = struct{}{}
} }
// check whether we need to reload the cache // check whether we need to reload the cache
// the accountUsers ID list is the source of truth and all the users should be in the cache // the accountUsers ID list is the source of truth and all the users should be in the cache
reload := len(accountUsers) != len(userData) reload := len(accountUsers) != len(data)
for user := range accountUsers { for user := range accountUsers {
if _, ok := userDataMap[user]; !ok { if _, ok := userDataMap[user]; !ok {
reload = true reload = true
@@ -345,19 +461,13 @@ func (am *DefaultAccountManager) lookupCache(accountUsers map[string]*User, acco
if reload { if reload {
// reload cache once avoiding loops // reload cache once avoiding loops
err := am.cacheManager.Delete(am.ctx, accountID) data, err = am.refreshCache(accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
data, err = am.cacheManager.Get(am.ctx, accountID)
if err != nil {
return nil, err
}
userData = data.([]*idp.UserData)
} }
return userData, err return data, err
} }
// updateAccountDomainAttributes updates the account domain attributes and then, saves the account // updateAccountDomainAttributes updates the account domain attributes and then, saves the account
@@ -412,7 +522,7 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
} }
// we should register the account ID to this user's metadata in our IDP manager // we should register the account ID to this user's metadata in our IDP manager
err = am.updateIDPMetadata(claims.UserId, existingAcc.Id) err = am.addAccountIDToIDPAppMeta(claims.UserId, existingAcc)
if err != nil { if err != nil {
return err return err
} }
@@ -450,7 +560,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(
} }
} }
err = am.updateIDPMetadata(claims.UserId, account.Id) err = am.addAccountIDToIDPAppMeta(claims.UserId, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -458,7 +568,61 @@ func (am *DefaultAccountManager) handleNewUserAccount(
return account, nil return account, nil
} }
// GetAccountWithAuthorizationClaims retrievs an account using JWT Claims. // redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(account *Account, userID string) error {
// only possible with the enabled IdP manager
if am.idpManager == nil {
log.Warnf("invites only work with enabled IdP manager")
return nil
}
user, err := am.lookupUserInCache(userID, account)
if err != nil {
return err
}
if user == nil {
return status.Errorf(codes.NotFound, "user %s not found in the IdP", userID)
}
if user.AppMetadata.WTPendingInvite {
log.Infof("redeeming invite for user %s account %s", userID, account.Id)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
go func() {
// set AppMetadata setting WTPendingInvite = false to indicate that it was redeemed
// it shouldn't be necessary since the IdP should redeem on login,
// but we make sure we do it if IdP is not capable of resetting teh flag.
_ = am.idpManager.UpdateUserAppMetadata(userID, idp.AppMetadata{WTAccountID: account.Id, WTPendingInvite: false}) //nolint
_, err = am.refreshCache(account.Id)
if err != nil {
log.Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
return
}
log.Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
}()
}
return nil
}
// GetAccountFromToken returns an account associated with this token
func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) {
account, err := am.getAccountWithAuthorizationClaims(claims)
if err != nil {
return nil, err
}
err = am.redeemInvite(account, claims.UserId)
if err != nil {
return nil, err
}
return account, nil
}
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
// if domain is of the PrivateCategory category, it will evaluate // if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain // if account is new, existing or if there is another account with the same domain
// //
@@ -475,12 +639,12 @@ func (am *DefaultAccountManager) handleNewUserAccount(
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes // Existing user + Existing account + Existing Indexed Domain -> Nothing changes
// //
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims( func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(
claims jwtclaims.AuthorizationClaims, claims jwtclaims.AuthorizationClaims,
) (*Account, error) { ) (*Account, error) {
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
accountFromID, err := am.GetAccountById(claims.AccountId) accountFromID, err := am.GetAccountById(claims.AccountId)
@@ -520,6 +684,11 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims(
} }
} }
func isDomainValid(domain string) bool {
re := regexp.MustCompile(`^([a-z0-9]+(-[a-z0-9]+)*\.)+[a-z]{2,}$`)
return re.Match([]byte(domain))
}
// AccountExists checks whether account exists (returns true) or not (returns false) // AccountExists checks whether account exists (returns true) or not (returns false)
func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) { func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) {
am.mux.Lock() am.mux.Lock()
@@ -577,18 +746,20 @@ func newAccountWithId(accountId, userId, domain string) *Account {
peers := make(map[string]*Peer) peers := make(map[string]*Peer)
users := make(map[string]*User) users := make(map[string]*User)
routes := make(map[string]*route.Route) routes := make(map[string]*route.Route)
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
users[userId] = NewAdminUser(userId) users[userId] = NewAdminUser(userId)
log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key) log.Debugf("created new account %s with setup key %s", accountId, defaultKey.Key)
acc := &Account{ acc := &Account{
Id: accountId, Id: accountId,
SetupKeys: setupKeys, SetupKeys: setupKeys,
Network: network, Network: network,
Peers: peers, Peers: peers,
Users: users, Users: users,
CreatedBy: userId, CreatedBy: userId,
Domain: domain, Domain: domain,
Routes: routes, Routes: routes,
NameServerGroups: nameServersGroups,
} }
addAllGroup(acc) addAllGroup(acc)

View File

@@ -127,7 +127,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
} }
} }
func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) { func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims type initUserParams jwtclaims.AuthorizationClaims
type test struct { type test struct {
@@ -140,6 +140,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
expectedMSG string expectedMSG string
expectedUserRole UserRole expectedUserRole UserRole
expectedDomainCategory string expectedDomainCategory string
expectedDomain string
expectedPrimaryDomainStatus bool expectedPrimaryDomainStatus bool
expectedCreatedBy string expectedCreatedBy string
expectedUsers []string expectedUsers []string
@@ -168,6 +169,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
expectedMSG: "account IDs shouldn't match", expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomainCategory: "", expectedDomainCategory: "",
expectedDomain: publicDomain,
expectedPrimaryDomainStatus: false, expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pub-domain-user", expectedCreatedBy: "pub-domain-user",
expectedUsers: []string{"pub-domain-user"}, expectedUsers: []string{"pub-domain-user"},
@@ -188,6 +190,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.NotEqual, testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match", expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomain: unknownDomain,
expectedDomainCategory: "", expectedDomainCategory: "",
expectedPrimaryDomainStatus: false, expectedPrimaryDomainStatus: false,
expectedCreatedBy: "unknown-domain-user", expectedCreatedBy: "unknown-domain-user",
@@ -205,6 +208,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.NotEqual, testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match", expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory, expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, expectedPrimaryDomainStatus: true,
expectedCreatedBy: "pvt-domain-user", expectedCreatedBy: "pvt-domain-user",
@@ -227,6 +231,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal, testingFunc: require.Equal,
expectedMSG: "account IDs should match", expectedMSG: "account IDs should match",
expectedUserRole: UserRoleUser, expectedUserRole: UserRoleUser,
expectedDomain: privateDomain,
expectedDomainCategory: PrivateCategory, expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId, expectedCreatedBy: defaultInitAccount.UserId,
@@ -244,6 +249,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal, testingFunc: require.Equal,
expectedMSG: "account IDs should match", expectedMSG: "account IDs should match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory, expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId, expectedCreatedBy: defaultInitAccount.UserId,
@@ -262,12 +268,32 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testingFunc: require.Equal, testingFunc: require.Equal,
expectedMSG: "account IDs should match", expectedMSG: "account IDs should match",
expectedUserRole: UserRoleAdmin, expectedUserRole: UserRoleAdmin,
expectedDomain: defaultInitAccount.Domain,
expectedDomainCategory: PrivateCategory, expectedDomainCategory: PrivateCategory,
expectedPrimaryDomainStatus: true, expectedPrimaryDomainStatus: true,
expectedCreatedBy: defaultInitAccount.UserId, expectedCreatedBy: defaultInitAccount.UserId,
expectedUsers: []string{defaultInitAccount.UserId}, expectedUsers: []string{defaultInitAccount.UserId},
} }
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
testCase7 := test{
name: "User With Private Category And Empty Domain",
inputClaims: jwtclaims.AuthorizationClaims{
Domain: "",
UserId: "pvt-domain-user",
DomainCategory: PrivateCategory,
},
inputInitUserParams: defaultInitAccount,
testingFunc: require.NotEqual,
expectedMSG: "account IDs shouldn't match",
expectedUserRole: UserRoleAdmin,
expectedDomain: "",
expectedDomainCategory: "",
expectedPrimaryDomainStatus: false,
expectedCreatedBy: "pvt-domain-user",
expectedUsers: []string{"pvt-domain-user"},
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
manager, err := createManager(t) manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager") require.NoError(t, err, "unable to create account manager")
@@ -284,7 +310,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id testCase.inputClaims.AccountId = initAccount.Id
} }
account, err := manager.GetAccountWithAuthorizationClaims(testCase.inputClaims) account, err := manager.GetAccountFromToken(testCase.inputClaims)
require.NoError(t, err, "support function failed") require.NoError(t, err, "support function failed")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -294,6 +320,7 @@ func TestDefaultAccountManager_GetAccountWithAuthorizationClaims(t *testing.T) {
require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "expected user role should match") require.EqualValues(t, testCase.expectedUserRole, account.Users[testCase.inputClaims.UserId].Role, "expected user role should match")
require.EqualValues(t, testCase.expectedDomainCategory, account.DomainCategory, "expected account domain category should match") require.EqualValues(t, testCase.expectedDomainCategory, account.DomainCategory, "expected account domain category should match")
require.EqualValues(t, testCase.expectedPrimaryDomainStatus, account.IsDomainPrimaryAccount, "expected account primary status should match") require.EqualValues(t, testCase.expectedPrimaryDomainStatus, account.IsDomainPrimaryAccount, "expected account primary status should match")
require.EqualValues(t, testCase.expectedDomain, account.Domain, "expected account domain should match")
}) })
} }
} }

View File

@@ -0,0 +1,52 @@
package server
import (
"fmt"
)
const (
// UserAlreadyExists indicates that user already exists
UserAlreadyExists ErrorType = 1
// AccountNotFound indicates that specified account hasn't been found
AccountNotFound ErrorType = iota
// PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled
PreconditionFailed ErrorType = iota
)
// ErrorType is a type of the Error
type ErrorType int32
// Error is an internal error
type Error struct {
errorType ErrorType
message string
}
// Type returns the Type of the error
func (e *Error) Type() ErrorType {
return e.errorType
}
// Error is an error string
func (e *Error) Error() string {
return e.message
}
// Errorf returns Error(errorType, fmt.Sprintf(format, a...)).
func Errorf(errorType ErrorType, format string, a ...interface{}) error {
return &Error{
errorType: errorType,
message: fmt.Sprintf(format, a...),
}
}
// FromError returns Error, true if the provided error is of type of Error. nil, false otherwise
func FromError(err error) (s *Error, ok bool) {
if err == nil {
return nil, true
}
if e, ok := err.(*Error); ok {
return e, true
}
return nil, false
}

View File

@@ -29,6 +29,7 @@ type FileStore struct {
PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"` PeerKeyId2DstRulesId map[string]map[string]struct{} `json:"-"`
PeerKeyID2RouteIDs map[string]map[string]struct{} `json:"-"` PeerKeyID2RouteIDs map[string]map[string]struct{} `json:"-"`
AccountPrefix2RouteIDs map[string]map[string][]string `json:"-"` AccountPrefix2RouteIDs map[string]map[string][]string `json:"-"`
InstallationID string
// mutex to synchronise Store read/write operations // mutex to synchronise Store read/write operations
mux sync.Mutex `json:"-"` mux sync.Mutex `json:"-"`
@@ -415,8 +416,10 @@ func (s *FileStore) GetAccountPeers(accountId string) ([]*Peer, error) {
// GetAllAccounts returns all accounts // GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts() (all []*Account) { func (s *FileStore) GetAllAccounts() (all []*Account) {
s.mux.Lock()
defer s.mux.Unlock()
for _, a := range s.Accounts { for _, a := range s.Accounts {
all = append(all, a) all = append(all, a.Copy())
} }
return all return all
@@ -566,3 +569,18 @@ func (s *FileStore) GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]
return routes, nil return routes, nil
} }
// GetInstallationID returns the installation ID from the store
func (s *FileStore) GetInstallationID() string {
return s.InstallationID
}
// SaveInstallationID saves the installation ID
func (s *FileStore) SaveInstallationID(id string) error {
s.mux.Lock()
defer s.mux.Unlock()
s.InstallationID = id
return s.persist(s.storeFile)
}

View File

@@ -181,7 +181,7 @@ func (s *GRPCServer) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest)
return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err) return nil, status.Errorf(codes.Internal, "invalid jwt token, err: %v", err)
} }
claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience) claims := jwtclaims.ExtractClaimsWithToken(token, s.config.HttpConfig.AuthAudience)
_, err = s.accountManager.GetAccountWithAuthorizationClaims(claims) _, err = s.accountManager.GetAccountFromToken(claims)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) return nil, status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
} }

View File

@@ -3,3 +3,5 @@ generate:
models: true models: true
embedded-spec: false embedded-spec: false
output: types.gen.go output: types.gen.go
compatibility:
always-prefix-enum-values: true

View File

@@ -11,6 +11,6 @@ fi
old_pwd=$(pwd) old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0")) script_path=$(dirname $(realpath "$0"))
cd "$script_path" cd "$script_path"
go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@v1.11.0 go install github.com/deepmap/oapi-codegen/cmd/oapi-codegen@4a1477f6a8ba6ca8115cc23bb2fb67f0b9fca18e
oapi-codegen --config cfg.yaml openapi.yml oapi-codegen --config cfg.yaml openapi.yml
cd "$old_pwd" cd "$old_pwd"

View File

@@ -16,6 +16,8 @@ tags:
description: Interact with and view information about rules. description: Interact with and view information about rules.
- name: Routes - name: Routes
description: Interact with and view information about routes. description: Interact with and view information about routes.
- name: DNS
description: Interact with and view information about DNS configuration.
components: components:
schemas: schemas:
User: User:
@@ -33,6 +35,10 @@ components:
role: role:
description: User's NetBird account role description: User's NetBird account role
type: string type: string
status:
description: User's status
type: string
enum: [ "active","invited","disabled" ]
auto_groups: auto_groups:
description: Groups to auto-assign to peers registered by this user description: Groups to auto-assign to peers registered by this user
type: array type: array
@@ -44,6 +50,7 @@ components:
- name - name
- role - role
- auto_groups - auto_groups
- status
UserRequest: UserRequest:
type: object type: object
properties: properties:
@@ -58,6 +65,27 @@ components:
required: required:
- role - role
- auto_groups - auto_groups
UserCreateRequest:
type: object
properties:
role:
description: User's NetBird account role
type: string
email:
description: User's Email to send invite to
type: string
name:
description: User's full name
type: string
auto_groups:
description: Groups to auto-assign to peers registered by this user
type: array
items:
type: string
required:
- role
- auto_groups
- email
PeerMinimum: PeerMinimum:
type: object type: object
properties: properties:
@@ -96,20 +124,18 @@ components:
type: array type: array
items: items:
$ref: '#/components/schemas/GroupMinimum' $ref: '#/components/schemas/GroupMinimum'
activated_by:
description: Provides information of who activated the Peer. User or Setup Key
type: object
properties:
type:
type: string
value:
type: string
required:
- type
- value
ssh_enabled: ssh_enabled:
description: Indicates whether SSH server is enabled on this peer description: Indicates whether SSH server is enabled on this peer
type: boolean type: boolean
user_id:
description: User ID of the user that enrolled this peer
type: string
hostname:
description: Hostname of the machine
type: string
ui_version:
description: Peer's desktop UI version
type: string
required: required:
- ip - ip
- connected - connected
@@ -117,8 +143,8 @@ components:
- os - os
- version - version
- groups - groups
- activated_by
- ssh_enabled - ssh_enabled
- hostname
SetupKey: SetupKey:
type: object type: object
properties: properties:
@@ -375,6 +401,76 @@ components:
enum: [ "network","network_id","description","enabled","peer","metric","masquerade" ] enum: [ "network","network_id","description","enabled","peer","metric","masquerade" ]
required: required:
- path - path
Nameserver:
type: object
properties:
ip:
description: Nameserver IP
type: string
ns_type:
description: Nameserver Type
type: string
enum: ["udp"]
port:
description: Nameserver Port
type: integer
required:
- ip
- ns_type
- port
NameserverGroupRequest:
type: object
properties:
name:
description: Nameserver group name
type: string
maxLength: 40
minLength: 1
description:
description: Nameserver group description
type: string
nameservers:
description: Nameserver group
minLength: 1
maxLength: 2
type: array
items:
$ref: '#/components/schemas/Nameserver'
enabled:
description: Nameserver group status
type: boolean
groups:
description: Nameserver group tag groups
type: array
items:
type: string
required:
- name
- description
- nameservers
- enabled
- groups
NameserverGroup:
allOf:
- type: object
properties:
id:
description: Nameserver group ID
type: string
required:
- id
- $ref: '#/components/schemas/NameserverGroupRequest'
NameserverGroupPatchOperation:
allOf:
- $ref: '#/components/schemas/PatchMinimum'
- type: object
properties:
path:
description: Nameserver group field to update in form /<field>
type: string
enum: [ "name","description","enabled","groups","nameservers" ]
required:
- path
responses: responses:
not_found: not_found:
@@ -429,6 +525,33 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
/api/users/:
post:
summary: Create a User (invite)
tags: [ Users]
security:
- BearerAuth: [ ]
requestBody:
description: User invite information
content:
'application/json':
schema:
$ref: '#/components/schemas/UserCreateRequest'
responses:
'200':
description: A User object
content:
application/json:
schema:
$ref: '#/components/schemas/User'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/users/{id}: /api/users/{id}:
put: put:
summary: Update information about a User summary: Update information about a User
@@ -1252,3 +1375,173 @@ paths:
"$ref": "#/components/responses/forbidden" "$ref": "#/components/responses/forbidden"
'500': '500':
"$ref": "#/components/responses/internal_error" "$ref": "#/components/responses/internal_error"
/api/dns/nameservers:
get:
summary: Returns a list of all Nameserver Groups
tags: [ DNS ]
security:
- BearerAuth: [ ]
responses:
'200':
description: A JSON Array of Nameserver Groups
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Creates a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
requestBody:
description: New Nameserver Groups request
content:
'application/json':
schema:
$ref: '#/components/schemas/NameserverGroupRequest'
responses:
'200':
description: A Nameserver Groups Object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/dns/nameservers/{id}:
get:
summary: Get information about a Nameserver Groups
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
responses:
'200':
description: A Nameserver Group object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update/Replace a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
requestBody:
description: Update Nameserver Group request
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroupRequest'
responses:
'200':
description: A Nameserver Group object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
patch:
summary: Update information about a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
requestBody:
description: Update Nameserver Group request using a list of json patch objects
content:
'application/json':
schema:
type: array
items:
$ref: '#/components/schemas/NameserverGroupPatchOperation'
responses:
'200':
description: A Nameserver Group object
content:
application/json:
schema:
$ref: '#/components/schemas/NameserverGroup'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Nameserver Group
tags: [ DNS ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Nameserver Group ID
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"

View File

@@ -1,6 +1,6 @@
// Package api provides primitives to interact with the openapi HTTP API. // Package api provides primitives to interact with the openapi HTTP API.
// //
// Code generated by github.com/deepmap/oapi-codegen version v1.11.0 DO NOT EDIT. // Code generated by github.com/deepmap/oapi-codegen version v1.11.1-0.20220912230023-4a1477f6a8ba DO NOT EDIT.
package api package api
import ( import (
@@ -24,6 +24,27 @@ const (
GroupPatchOperationPathPeers GroupPatchOperationPath = "peers" GroupPatchOperationPathPeers GroupPatchOperationPath = "peers"
) )
// Defines values for NameserverNsType.
const (
NameserverNsTypeUdp NameserverNsType = "udp"
)
// Defines values for NameserverGroupPatchOperationOp.
const (
NameserverGroupPatchOperationOpAdd NameserverGroupPatchOperationOp = "add"
NameserverGroupPatchOperationOpRemove NameserverGroupPatchOperationOp = "remove"
NameserverGroupPatchOperationOpReplace NameserverGroupPatchOperationOp = "replace"
)
// Defines values for NameserverGroupPatchOperationPath.
const (
NameserverGroupPatchOperationPathDescription NameserverGroupPatchOperationPath = "description"
NameserverGroupPatchOperationPathEnabled NameserverGroupPatchOperationPath = "enabled"
NameserverGroupPatchOperationPathGroups NameserverGroupPatchOperationPath = "groups"
NameserverGroupPatchOperationPathName NameserverGroupPatchOperationPath = "name"
NameserverGroupPatchOperationPathNameservers NameserverGroupPatchOperationPath = "nameservers"
)
// Defines values for PatchMinimumOp. // Defines values for PatchMinimumOp.
const ( const (
PatchMinimumOpAdd PatchMinimumOp = "add" PatchMinimumOpAdd PatchMinimumOp = "add"
@@ -66,321 +87,427 @@ const (
RulePatchOperationPathSources RulePatchOperationPath = "sources" RulePatchOperationPathSources RulePatchOperationPath = "sources"
) )
// Defines values for UserStatus.
const (
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
)
// Group defines model for Group. // Group defines model for Group.
type Group struct { type Group struct {
// Group ID // Id Group ID
Id string `json:"id"` Id string `json:"id"`
// Group Name identifier // Name Group Name identifier
Name string `json:"name"` Name string `json:"name"`
// List of peers object // Peers List of peers object
Peers []PeerMinimum `json:"peers"` Peers []PeerMinimum `json:"peers"`
// Count of peers associated to the group // PeersCount Count of peers associated to the group
PeersCount int `json:"peers_count"` PeersCount int `json:"peers_count"`
} }
// GroupMinimum defines model for GroupMinimum. // GroupMinimum defines model for GroupMinimum.
type GroupMinimum struct { type GroupMinimum struct {
// Group ID // Id Group ID
Id string `json:"id"` Id string `json:"id"`
// Group Name identifier // Name Group Name identifier
Name string `json:"name"` Name string `json:"name"`
// Count of peers associated to the group // PeersCount Count of peers associated to the group
PeersCount int `json:"peers_count"` PeersCount int `json:"peers_count"`
} }
// GroupPatchOperation defines model for GroupPatchOperation. // GroupPatchOperation defines model for GroupPatchOperation.
type GroupPatchOperation struct { type GroupPatchOperation struct {
// Patch operation type // Op Patch operation type
Op GroupPatchOperationOp `json:"op"` Op GroupPatchOperationOp `json:"op"`
// Group field to update in form /<field> // Path Group field to update in form /<field>
Path GroupPatchOperationPath `json:"path"` Path GroupPatchOperationPath `json:"path"`
// Values to be applied // Value Values to be applied
Value []string `json:"value"` Value []string `json:"value"`
} }
// Patch operation type // GroupPatchOperationOp Patch operation type
type GroupPatchOperationOp string type GroupPatchOperationOp string
// Group field to update in form /<field> // GroupPatchOperationPath Group field to update in form /<field>
type GroupPatchOperationPath string type GroupPatchOperationPath string
// Nameserver defines model for Nameserver.
type Nameserver struct {
// Ip Nameserver IP
Ip string `json:"ip"`
// NsType Nameserver Type
NsType NameserverNsType `json:"ns_type"`
// Port Nameserver Port
Port int `json:"port"`
}
// NameserverNsType Nameserver Type
type NameserverNsType string
// NameserverGroup defines model for NameserverGroup.
type NameserverGroup struct {
// Description Nameserver group description
Description string `json:"description"`
// Enabled Nameserver group status
Enabled bool `json:"enabled"`
// Groups Nameserver group tag groups
Groups []string `json:"groups"`
// Id Nameserver group ID
Id string `json:"id"`
// Name Nameserver group name
Name string `json:"name"`
// Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"`
}
// NameserverGroupPatchOperation defines model for NameserverGroupPatchOperation.
type NameserverGroupPatchOperation struct {
// Op Patch operation type
Op NameserverGroupPatchOperationOp `json:"op"`
// Path Nameserver group field to update in form /<field>
Path NameserverGroupPatchOperationPath `json:"path"`
// Value Values to be applied
Value []string `json:"value"`
}
// NameserverGroupPatchOperationOp Patch operation type
type NameserverGroupPatchOperationOp string
// NameserverGroupPatchOperationPath Nameserver group field to update in form /<field>
type NameserverGroupPatchOperationPath string
// NameserverGroupRequest defines model for NameserverGroupRequest.
type NameserverGroupRequest struct {
// Description Nameserver group description
Description string `json:"description"`
// Enabled Nameserver group status
Enabled bool `json:"enabled"`
// Groups Nameserver group tag groups
Groups []string `json:"groups"`
// Name Nameserver group name
Name string `json:"name"`
// Nameservers Nameserver group
Nameservers []Nameserver `json:"nameservers"`
}
// PatchMinimum defines model for PatchMinimum. // PatchMinimum defines model for PatchMinimum.
type PatchMinimum struct { type PatchMinimum struct {
// Patch operation type // Op Patch operation type
Op PatchMinimumOp `json:"op"` Op PatchMinimumOp `json:"op"`
// Values to be applied // Value Values to be applied
Value []string `json:"value"` Value []string `json:"value"`
} }
// Patch operation type // PatchMinimumOp Patch operation type
type PatchMinimumOp string type PatchMinimumOp string
// Peer defines model for Peer. // Peer defines model for Peer.
type Peer struct { type Peer struct {
// Provides information of who activated the Peer. User or Setup Key // Connected Peer to Management connection status
ActivatedBy struct {
Type string `json:"type"`
Value string `json:"value"`
} `json:"activated_by"`
// Peer to Management connection status
Connected bool `json:"connected"` Connected bool `json:"connected"`
// Groups that the peer belongs to // Groups Groups that the peer belongs to
Groups []GroupMinimum `json:"groups"` Groups []GroupMinimum `json:"groups"`
// Peer ID // Hostname Hostname of the machine
Hostname string `json:"hostname"`
// Id Peer ID
Id string `json:"id"` Id string `json:"id"`
// Peer's IP address // Ip Peer's IP address
Ip string `json:"ip"` Ip string `json:"ip"`
// Last time peer connected to Netbird's management service // LastSeen Last time peer connected to Netbird's management service
LastSeen time.Time `json:"last_seen"` LastSeen time.Time `json:"last_seen"`
// Peer's hostname // Name Peer's hostname
Name string `json:"name"` Name string `json:"name"`
// Peer's operating system and version // Os Peer's operating system and version
Os string `json:"os"` Os string `json:"os"`
// Indicates whether SSH server is enabled on this peer // SshEnabled Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"` SshEnabled bool `json:"ssh_enabled"`
// Peer's daemon or cli version // UiVersion Peer's desktop UI version
UiVersion *string `json:"ui_version,omitempty"`
// UserId User ID of the user that enrolled this peer
UserId *string `json:"user_id,omitempty"`
// Version Peer's daemon or cli version
Version string `json:"version"` Version string `json:"version"`
} }
// PeerMinimum defines model for PeerMinimum. // PeerMinimum defines model for PeerMinimum.
type PeerMinimum struct { type PeerMinimum struct {
// Peer ID // Id Peer ID
Id string `json:"id"` Id string `json:"id"`
// Peer's hostname // Name Peer's hostname
Name string `json:"name"` Name string `json:"name"`
} }
// Route defines model for Route. // Route defines model for Route.
type Route struct { type Route struct {
// Route description // Description Route description
Description string `json:"description"` Description string `json:"description"`
// Route status // Enabled Route status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Route Id // Id Route Id
Id string `json:"id"` Id string `json:"id"`
// Indicate if peer should masquerade traffic to this route's prefix // Masquerade Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"` Masquerade bool `json:"masquerade"`
// Route metric number. Lowest number has higher priority // Metric Route metric number. Lowest number has higher priority
Metric int `json:"metric"` Metric int `json:"metric"`
// Network range in CIDR format // Network Network range in CIDR format
Network string `json:"network"` Network string `json:"network"`
// Route network identifier, to group HA routes // NetworkId Route network identifier, to group HA routes
NetworkId string `json:"network_id"` NetworkId string `json:"network_id"`
// Network type indicating if it is IPv4 or IPv6 // NetworkType Network type indicating if it is IPv4 or IPv6
NetworkType string `json:"network_type"` NetworkType string `json:"network_type"`
// Peer Identifier associated with route // Peer Peer Identifier associated with route
Peer string `json:"peer"` Peer string `json:"peer"`
} }
// RoutePatchOperation defines model for RoutePatchOperation. // RoutePatchOperation defines model for RoutePatchOperation.
type RoutePatchOperation struct { type RoutePatchOperation struct {
// Patch operation type // Op Patch operation type
Op RoutePatchOperationOp `json:"op"` Op RoutePatchOperationOp `json:"op"`
// Route field to update in form /<field> // Path Route field to update in form /<field>
Path RoutePatchOperationPath `json:"path"` Path RoutePatchOperationPath `json:"path"`
// Values to be applied // Value Values to be applied
Value []string `json:"value"` Value []string `json:"value"`
} }
// Patch operation type // RoutePatchOperationOp Patch operation type
type RoutePatchOperationOp string type RoutePatchOperationOp string
// Route field to update in form /<field> // RoutePatchOperationPath Route field to update in form /<field>
type RoutePatchOperationPath string type RoutePatchOperationPath string
// RouteRequest defines model for RouteRequest. // RouteRequest defines model for RouteRequest.
type RouteRequest struct { type RouteRequest struct {
// Route description // Description Route description
Description string `json:"description"` Description string `json:"description"`
// Route status // Enabled Route status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Indicate if peer should masquerade traffic to this route's prefix // Masquerade Indicate if peer should masquerade traffic to this route's prefix
Masquerade bool `json:"masquerade"` Masquerade bool `json:"masquerade"`
// Route metric number. Lowest number has higher priority // Metric Route metric number. Lowest number has higher priority
Metric int `json:"metric"` Metric int `json:"metric"`
// Network range in CIDR format // Network Network range in CIDR format
Network string `json:"network"` Network string `json:"network"`
// Route network identifier, to group HA routes // NetworkId Route network identifier, to group HA routes
NetworkId string `json:"network_id"` NetworkId string `json:"network_id"`
// Peer Identifier associated with route // Peer Peer Identifier associated with route
Peer string `json:"peer"` Peer string `json:"peer"`
} }
// Rule defines model for Rule. // Rule defines model for Rule.
type Rule struct { type Rule struct {
// Rule friendly description // Description Rule friendly description
Description string `json:"description"` Description string `json:"description"`
// Rule destination groups // Destinations Rule destination groups
Destinations []GroupMinimum `json:"destinations"` Destinations []GroupMinimum `json:"destinations"`
// Rules status // Disabled Rules status
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
// Rule flow, currently, only "bidirect" for bi-directional traffic is accepted // Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"` Flow string `json:"flow"`
// Rule ID // Id Rule ID
Id string `json:"id"` Id string `json:"id"`
// Rule name identifier // Name Rule name identifier
Name string `json:"name"` Name string `json:"name"`
// Rule source groups // Sources Rule source groups
Sources []GroupMinimum `json:"sources"` Sources []GroupMinimum `json:"sources"`
} }
// RuleMinimum defines model for RuleMinimum. // RuleMinimum defines model for RuleMinimum.
type RuleMinimum struct { type RuleMinimum struct {
// Rule friendly description // Description Rule friendly description
Description string `json:"description"` Description string `json:"description"`
// Rules status // Disabled Rules status
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
// Rule flow, currently, only "bidirect" for bi-directional traffic is accepted // Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"` Flow string `json:"flow"`
// Rule name identifier // Name Rule name identifier
Name string `json:"name"` Name string `json:"name"`
} }
// RulePatchOperation defines model for RulePatchOperation. // RulePatchOperation defines model for RulePatchOperation.
type RulePatchOperation struct { type RulePatchOperation struct {
// Patch operation type // Op Patch operation type
Op RulePatchOperationOp `json:"op"` Op RulePatchOperationOp `json:"op"`
// Rule field to update in form /<field> // Path Rule field to update in form /<field>
Path RulePatchOperationPath `json:"path"` Path RulePatchOperationPath `json:"path"`
// Values to be applied // Value Values to be applied
Value []string `json:"value"` Value []string `json:"value"`
} }
// Patch operation type // RulePatchOperationOp Patch operation type
type RulePatchOperationOp string type RulePatchOperationOp string
// Rule field to update in form /<field> // RulePatchOperationPath Rule field to update in form /<field>
type RulePatchOperationPath string type RulePatchOperationPath string
// SetupKey defines model for SetupKey. // SetupKey defines model for SetupKey.
type SetupKey struct { type SetupKey struct {
// Setup key groups to auto-assign to peers registered with this key // AutoGroups Setup key groups to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Setup Key expiration date // Expires Setup Key expiration date
Expires time.Time `json:"expires"` Expires time.Time `json:"expires"`
// Setup Key ID // Id Setup Key ID
Id string `json:"id"` Id string `json:"id"`
// Setup Key value // Key Setup Key value
Key string `json:"key"` Key string `json:"key"`
// Setup key last usage date // LastUsed Setup key last usage date
LastUsed time.Time `json:"last_used"` LastUsed time.Time `json:"last_used"`
// Setup key name identifier // Name Setup key name identifier
Name string `json:"name"` Name string `json:"name"`
// Setup key revocation status // Revoked Setup key revocation status
Revoked bool `json:"revoked"` Revoked bool `json:"revoked"`
// Setup key status, "valid", "overused","expired" or "revoked" // State Setup key status, "valid", "overused","expired" or "revoked"
State string `json:"state"` State string `json:"state"`
// Setup key type, one-off for single time usage and reusable // Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"` Type string `json:"type"`
// Setup key last update date // UpdatedAt Setup key last update date
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
// Usage count of setup key // UsedTimes Usage count of setup key
UsedTimes int `json:"used_times"` UsedTimes int `json:"used_times"`
// Setup key validity status // Valid Setup key validity status
Valid bool `json:"valid"` Valid bool `json:"valid"`
} }
// SetupKeyRequest defines model for SetupKeyRequest. // SetupKeyRequest defines model for SetupKeyRequest.
type SetupKeyRequest struct { type SetupKeyRequest struct {
// Setup key groups to auto-assign to peers registered with this key // AutoGroups Setup key groups to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// Expiration time in seconds // ExpiresIn Expiration time in seconds
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
// Setup Key name // Name Setup Key name
Name string `json:"name"` Name string `json:"name"`
// Setup key revocation status // Revoked Setup key revocation status
Revoked bool `json:"revoked"` Revoked bool `json:"revoked"`
// Setup key type, one-off for single time usage and reusable // Type Setup key type, one-off for single time usage and reusable
Type string `json:"type"` Type string `json:"type"`
} }
// User defines model for User. // User defines model for User.
type User struct { type User struct {
// Groups to auto-assign to peers registered by this user // AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// User's email address // Email User's email address
Email string `json:"email"` Email string `json:"email"`
// User ID // Id User ID
Id string `json:"id"` Id string `json:"id"`
// User's name from idp provider // Name User's name from idp provider
Name string `json:"name"` Name string `json:"name"`
// User's NetBird account role // Role User's NetBird account role
Role string `json:"role"`
// Status User's status
Status UserStatus `json:"status"`
}
// UserStatus User's status
type UserStatus string
// UserCreateRequest defines model for UserCreateRequest.
type UserCreateRequest struct {
// AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"`
// Email User's Email to send invite to
Email string `json:"email"`
// Name User's full name
Name *string `json:"name,omitempty"`
// Role User's NetBird account role
Role string `json:"role"` Role string `json:"role"`
} }
// UserRequest defines model for UserRequest. // UserRequest defines model for UserRequest.
type UserRequest struct { type UserRequest struct {
// Groups to auto-assign to peers registered by this user // AutoGroups Groups to auto-assign to peers registered by this user
AutoGroups []string `json:"auto_groups"` AutoGroups []string `json:"auto_groups"`
// User's NetBird account role // Role User's NetBird account role
Role string `json:"role"` Role string `json:"role"`
} }
// PatchApiDnsNameserversIdJSONBody defines parameters for PatchApiDnsNameserversId.
type PatchApiDnsNameserversIdJSONBody = []NameserverGroupPatchOperation
// PostApiGroupsJSONBody defines parameters for PostApiGroups. // PostApiGroupsJSONBody defines parameters for PostApiGroups.
type PostApiGroupsJSONBody struct { type PostApiGroupsJSONBody struct {
Name string `json:"name"` Name string `json:"name"`
@@ -402,28 +529,22 @@ type PutApiPeersIdJSONBody struct {
SshEnabled bool `json:"ssh_enabled"` SshEnabled bool `json:"ssh_enabled"`
} }
// PostApiRoutesJSONBody defines parameters for PostApiRoutes.
type PostApiRoutesJSONBody = RouteRequest
// PatchApiRoutesIdJSONBody defines parameters for PatchApiRoutesId. // PatchApiRoutesIdJSONBody defines parameters for PatchApiRoutesId.
type PatchApiRoutesIdJSONBody = []RoutePatchOperation type PatchApiRoutesIdJSONBody = []RoutePatchOperation
// PutApiRoutesIdJSONBody defines parameters for PutApiRoutesId.
type PutApiRoutesIdJSONBody = RouteRequest
// PostApiRulesJSONBody defines parameters for PostApiRules. // PostApiRulesJSONBody defines parameters for PostApiRules.
type PostApiRulesJSONBody struct { type PostApiRulesJSONBody struct {
// Rule friendly description // Description Rule friendly description
Description string `json:"description"` Description string `json:"description"`
Destinations *[]string `json:"destinations,omitempty"` Destinations *[]string `json:"destinations,omitempty"`
// Rules status // Disabled Rules status
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
// Rule flow, currently, only "bidirect" for bi-directional traffic is accepted // Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"` Flow string `json:"flow"`
// Rule name identifier // Name Rule name identifier
Name string `json:"name"` Name string `json:"name"`
Sources *[]string `json:"sources,omitempty"` Sources *[]string `json:"sources,omitempty"`
} }
@@ -433,29 +554,29 @@ type PatchApiRulesIdJSONBody = []RulePatchOperation
// PutApiRulesIdJSONBody defines parameters for PutApiRulesId. // PutApiRulesIdJSONBody defines parameters for PutApiRulesId.
type PutApiRulesIdJSONBody struct { type PutApiRulesIdJSONBody struct {
// Rule friendly description // Description Rule friendly description
Description string `json:"description"` Description string `json:"description"`
Destinations *[]string `json:"destinations,omitempty"` Destinations *[]string `json:"destinations,omitempty"`
// Rules status // Disabled Rules status
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
// Rule flow, currently, only "bidirect" for bi-directional traffic is accepted // Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"` Flow string `json:"flow"`
// Rule name identifier // Name Rule name identifier
Name string `json:"name"` Name string `json:"name"`
Sources *[]string `json:"sources,omitempty"` Sources *[]string `json:"sources,omitempty"`
} }
// PostApiSetupKeysJSONBody defines parameters for PostApiSetupKeys. // PostApiDnsNameserversJSONRequestBody defines body for PostApiDnsNameservers for application/json ContentType.
type PostApiSetupKeysJSONBody = SetupKeyRequest type PostApiDnsNameserversJSONRequestBody = NameserverGroupRequest
// PutApiSetupKeysIdJSONBody defines parameters for PutApiSetupKeysId. // PatchApiDnsNameserversIdJSONRequestBody defines body for PatchApiDnsNameserversId for application/json ContentType.
type PutApiSetupKeysIdJSONBody = SetupKeyRequest type PatchApiDnsNameserversIdJSONRequestBody = PatchApiDnsNameserversIdJSONBody
// PutApiUsersIdJSONBody defines parameters for PutApiUsersId. // PutApiDnsNameserversIdJSONRequestBody defines body for PutApiDnsNameserversId for application/json ContentType.
type PutApiUsersIdJSONBody = UserRequest type PutApiDnsNameserversIdJSONRequestBody = NameserverGroupRequest
// PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType. // PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType.
type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody type PostApiGroupsJSONRequestBody PostApiGroupsJSONBody
@@ -470,13 +591,13 @@ type PutApiGroupsIdJSONRequestBody PutApiGroupsIdJSONBody
type PutApiPeersIdJSONRequestBody PutApiPeersIdJSONBody type PutApiPeersIdJSONRequestBody PutApiPeersIdJSONBody
// PostApiRoutesJSONRequestBody defines body for PostApiRoutes for application/json ContentType. // PostApiRoutesJSONRequestBody defines body for PostApiRoutes for application/json ContentType.
type PostApiRoutesJSONRequestBody = PostApiRoutesJSONBody type PostApiRoutesJSONRequestBody = RouteRequest
// PatchApiRoutesIdJSONRequestBody defines body for PatchApiRoutesId for application/json ContentType. // PatchApiRoutesIdJSONRequestBody defines body for PatchApiRoutesId for application/json ContentType.
type PatchApiRoutesIdJSONRequestBody = PatchApiRoutesIdJSONBody type PatchApiRoutesIdJSONRequestBody = PatchApiRoutesIdJSONBody
// PutApiRoutesIdJSONRequestBody defines body for PutApiRoutesId for application/json ContentType. // PutApiRoutesIdJSONRequestBody defines body for PutApiRoutesId for application/json ContentType.
type PutApiRoutesIdJSONRequestBody = PutApiRoutesIdJSONBody type PutApiRoutesIdJSONRequestBody = RouteRequest
// PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType. // PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType.
type PostApiRulesJSONRequestBody PostApiRulesJSONBody type PostApiRulesJSONRequestBody PostApiRulesJSONBody
@@ -488,10 +609,13 @@ type PatchApiRulesIdJSONRequestBody = PatchApiRulesIdJSONBody
type PutApiRulesIdJSONRequestBody PutApiRulesIdJSONBody type PutApiRulesIdJSONRequestBody PutApiRulesIdJSONBody
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType. // PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
type PostApiSetupKeysJSONRequestBody = PostApiSetupKeysJSONBody type PostApiSetupKeysJSONRequestBody = SetupKeyRequest
// PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType. // PutApiSetupKeysIdJSONRequestBody defines body for PutApiSetupKeysId for application/json ContentType.
type PutApiSetupKeysIdJSONRequestBody = PutApiSetupKeysIdJSONBody type PutApiSetupKeysIdJSONRequestBody = SetupKeyRequest
// PostApiUsersJSONRequestBody defines body for PostApiUsers for application/json ContentType.
type PostApiUsersJSONRequestBody = UserCreateRequest
// PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType. // PutApiUsersIdJSONRequestBody defines body for PutApiUsersId for application/json ContentType.
type PutApiUsersIdJSONRequestBody = PutApiUsersIdJSONBody type PutApiUsersIdJSONRequestBody = UserRequest

View File

@@ -67,14 +67,14 @@ func initGroupTestData(groups ...*server.Group) *Groups {
} }
return nil, fmt.Errorf("peer not found") return nil, fmt.Errorf("peer not found")
}, },
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Peers: TestPeers, Peers: TestPeers,
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"id-existed": &server.Group{ID: "id-existed", Peers: []string{"A", "B"}}, "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": &server.Group{ID: "id-all", Name: "All"}}, "id-all": {ID: "id-all", Name: "All"}},
}, nil }, nil
}, },
}, },

View File

@@ -34,12 +34,14 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
keysHandler := NewSetupKeysHandler(accountManager, authAudience) keysHandler := NewSetupKeysHandler(accountManager, authAudience)
userHandler := NewUserHandler(accountManager, authAudience) userHandler := NewUserHandler(accountManager, authAudience)
routesHandler := NewRoutes(accountManager, authAudience) routesHandler := NewRoutes(accountManager, authAudience)
nameserversHandler := NewNameservers(accountManager, authAudience)
apiHandler.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/api/peers", peersHandler.GetPeers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer). apiHandler.HandleFunc("/api/peers/{id}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS") Methods("GET", "PUT", "DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/api/users", userHandler.GetUsers).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") apiHandler.HandleFunc("/api/users/{id}", userHandler.UpdateUser).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/users", userHandler.CreateUserHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/api/setup-keys", keysHandler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS") apiHandler.HandleFunc("/api/setup-keys", keysHandler.CreateSetupKeyHandler).Methods("POST", "OPTIONS")
@@ -67,6 +69,13 @@ func APIHandler(accountManager s.AccountManager, authIssuer string, authAudience
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS") apiHandler.HandleFunc("/api/routes/{id}", routesHandler.GetRouteHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS") apiHandler.HandleFunc("/api/routes/{id}", routesHandler.DeleteRouteHandler).Methods("DELETE", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers", nameserversHandler.GetAllNameserversHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers", nameserversHandler.CreateNameserverGroupHandler).Methods("POST", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.UpdateNameserverGroupHandler).Methods("PUT", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.PatchNameserverGroupHandler).Methods("PATCH", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.GetNameserverGroupHandler).Methods("GET", "OPTIONS")
apiHandler.HandleFunc("/api/dns/nameservers/{id}", nameserversHandler.DeleteNameserverGroupHandler).Methods("DELETE", "OPTIONS")
return apiHandler, nil return apiHandler, nil
} }

View File

@@ -0,0 +1,286 @@
package http
import (
"encoding/json"
"fmt"
"github.com/gorilla/mux"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"net/http"
)
// Nameservers is the nameserver group handler of the account
type Nameservers struct {
jwtExtractor jwtclaims.ClaimsExtractor
accountManager server.AccountManager
authAudience string
}
// NewNameservers returns a new instance of Nameservers handler
func NewNameservers(accountManager server.AccountManager, authAudience string) *Nameservers {
return &Nameservers{
accountManager: accountManager,
authAudience: authAudience,
jwtExtractor: *jwtclaims.NewClaimsExtractor(nil),
}
}
// GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroups, err := h.accountManager.ListNameServerGroups(account.Id)
if err != nil {
toHTTPError(err, w)
return
}
apiNameservers := make([]*api.NameserverGroup, 0)
for _, r := range nsGroups {
apiNameservers = append(apiNameservers, toNameserverGroupResponse(r))
}
writeJSONObject(w, apiNameservers)
}
// CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
var req api.PostApiDnsNameserversJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(account.Id, req.Name, req.Description, nsList, req.Groups, req.Enabled)
if err != nil {
toHTTPError(err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp)
}
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
return
}
var req api.PutApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
nsList, err := toServerNSList(req.Nameservers)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
updatedNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID,
Name: req.Name,
Description: req.Description,
NameServers: nsList,
Groups: req.Groups,
Enabled: req.Enabled,
}
err = h.accountManager.SaveNameServerGroup(account.Id, updatedNSGroup)
if err != nil {
toHTTPError(err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp)
}
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
return
}
var req api.PatchApiDnsNameserversIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
var operations []server.NameServerGroupUpdateOperation
for _, patch := range req {
if patch.Op != api.NameserverGroupPatchOperationOpReplace {
http.Error(w, fmt.Sprintf("nameserver groups only accepts replace operations, got %s", patch.Op),
http.StatusBadRequest)
return
}
switch patch.Path {
case api.NameserverGroupPatchOperationPathName:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupName,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathDescription:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupDescription,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathNameservers:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupNameServers,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathGroups:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupGroups,
Values: patch.Value,
})
case api.NameserverGroupPatchOperationPathEnabled:
operations = append(operations, server.NameServerGroupUpdateOperation{
Type: server.UpdateNameServerGroupEnabled,
Values: patch.Value,
})
default:
http.Error(w, "invalid patch path", http.StatusBadRequest)
return
}
}
updatedNSGroup, err := h.accountManager.UpdateNameServerGroup(account.Id, nsGroupID, operations)
if err != nil {
toHTTPError(err, w)
return
}
resp := toNameserverGroupResponse(updatedNSGroup)
writeJSONObject(w, &resp)
}
// DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
return
}
err = h.accountManager.DeleteNameServerGroup(account.Id, nsGroupID)
if err != nil {
toHTTPError(err, w)
return
}
writeJSONObject(w, "")
}
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
nsGroupID := mux.Vars(r)["id"]
if len(nsGroupID) == 0 {
http.Error(w, "invalid nameserver group ID", http.StatusBadRequest)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(account.Id, nsGroupID)
if err != nil {
toHTTPError(err, w)
return
}
resp := toNameserverGroupResponse(nsGroup)
writeJSONObject(w, &resp)
}
func toServerNSList(apiNSList []api.Nameserver) ([]nbdns.NameServer, error) {
var nsList []nbdns.NameServer
for _, apiNS := range apiNSList {
parsed, err := nbdns.ParseNameServerURL(fmt.Sprintf("%s://%s:%d", apiNS.NsType, apiNS.Ip, apiNS.Port))
if err != nil {
return nil, err
}
nsList = append(nsList, parsed)
}
return nsList, nil
}
func toNameserverGroupResponse(serverNSGroup *nbdns.NameServerGroup) *api.NameserverGroup {
var nsList []api.Nameserver
for _, ns := range serverNSGroup.NameServers {
apiNS := api.Nameserver{
Ip: ns.IP.String(),
NsType: api.NameserverNsType(ns.NSType.String()),
Port: ns.Port,
}
nsList = append(nsList, apiNS)
}
return &api.NameserverGroup{
Id: serverNSGroup.ID,
Name: serverNSGroup.Name,
Description: serverNSGroup.Description,
Groups: serverNSGroup.Groups,
Nameservers: nsList,
Enabled: serverNSGroup.Enabled,
}
}

View File

@@ -0,0 +1,287 @@
package http
import (
"bytes"
"encoding/json"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"io"
"net/http"
"net/http/httptest"
"net/netip"
"testing"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
const (
existingNSGroupID = "existingNSGroupID"
notFoundNSGroupID = "notFoundNSGroupID"
testNSGroupAccountID = "test_id"
)
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
}
var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: "super",
Description: "super",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{"testing"},
Enabled: true,
}
func initNameserversTestData() *Nameservers {
return &Nameservers{
accountManager: &mock_server.MockAccountManager{
GetNameServerGroupFunc: func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if nsGroupID == existingNSGroupID {
return baseExistingNSGroup.Copy(), nil
}
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
},
CreateNameServerGroupFunc: func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
return &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: name,
Description: description,
NameServers: nameServerList,
Groups: groups,
Enabled: enabled,
}, nil
},
DeleteNameServerGroupFunc: func(accountID, nsGroupID string) error {
return nil
},
SaveNameServerGroupFunc: func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
if nsGroupToSave.ID == existingNSGroupID {
return nil
}
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
UpdateNameServerGroupFunc: func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
nsGroupToUpdate := baseExistingNSGroup.Copy()
if nsGroupID != nsGroupToUpdate.ID {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
for _, operation := range operations {
switch operation.Type {
case server.UpdateNameServerGroupName:
nsGroupToUpdate.Name = operation.Values[0]
case server.UpdateNameServerGroupDescription:
nsGroupToUpdate.Description = operation.Values[0]
case server.UpdateNameServerGroupNameServers:
var parsedNSList []nbdns.NameServer
for _, nsURL := range operation.Values {
parsed, err := nbdns.ParseNameServerURL(nsURL)
if err != nil {
return nil, err
}
parsedNSList = append(parsedNSList, parsed)
}
nsGroupToUpdate.NameServers = parsedNSList
}
}
return nsGroupToUpdate, nil
},
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
return testingNSAccount, nil
},
},
authAudience: "",
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: testNSGroupAccountID,
}
},
},
}
}
func TestNameserversHandlers(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedNSGroup *api.NameserverGroup
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "Get Existing Nameserver Group",
requestType: http.MethodGet,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: toNameserverGroupResponse(baseExistingNSGroup),
},
{
name: "Get Not Existing Nameserver Group",
requestType: http.MethodGet,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
expectedStatus: http.StatusNotFound,
},
{
name: "POST OK",
requestType: http.MethodPost,
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: "name",
Description: "Post",
Nameservers: []api.Nameserver{
{
Ip: "1.1.1.1",
NsType: "udp",
Port: 53,
},
},
Groups: []string{"group"},
Enabled: true,
},
},
{
name: "POST Invalid Nameserver",
requestType: http.MethodPost,
requestPath: "/api/dns/nameservers",
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1000\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "PUT OK",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: "name",
Description: "Post",
Nameservers: []api.Nameserver{
{
Ip: "1.1.1.1",
NsType: "udp",
Port: 53,
},
},
Groups: []string{"group"},
Enabled: true,
},
},
{
name: "PUT Not Existing Nameserver Group",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"1.1.1.1\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
{
name: "PUT Invalid Nameserver",
requestType: http.MethodPut,
requestPath: "/api/dns/nameservers/" + notFoundNSGroupID,
requestBody: bytes.NewBuffer(
[]byte("{\"name\":\"name\",\"Description\":\"Post\",\"nameservers\":[{\"ip\":\"100\",\"ns_type\":\"udp\",\"port\":53}],\"groups\":[\"group\"],\"enabled\":true}")),
expectedStatus: http.StatusBadRequest,
expectedBody: false,
},
{
name: "PATCH OK",
requestType: http.MethodPatch,
requestPath: "/api/dns/nameservers/" + existingNSGroupID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedNSGroup: &api.NameserverGroup{
Id: existingNSGroupID,
Name: baseExistingNSGroup.Name,
Description: "NewDesc",
Nameservers: toNameserverGroupResponse(baseExistingNSGroup).Nameservers,
Groups: baseExistingNSGroup.Groups,
Enabled: baseExistingNSGroup.Enabled,
},
},
{
name: "PATCH Invalid Nameserver Group OK",
requestType: http.MethodPatch,
requestPath: "/api/dns/nameservers/" + notFoundRouteID,
requestBody: bytes.NewBufferString("[{\"op\":\"replace\",\"path\":\"description\",\"value\":[\"NewDesc\"]}]"),
expectedStatus: http.StatusNotFound,
expectedBody: false,
},
}
p := initNameserversTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{id}", p.GetNameserverGroupHandler).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroupHandler).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{id}", p.DeleteNameserverGroupHandler).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{id}", p.UpdateNameserverGroupHandler).Methods("PUT")
router.HandleFunc("/api/dns/nameservers/{id}", p.PatchNameserverGroupHandler).Methods("PATCH")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
got := &api.NameserverGroup{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, tc.expectedNSGroup, got)
})
}
}

View File

@@ -11,7 +11,7 @@ import (
"net/http" "net/http"
) )
//Peers is a handler that returns peers of the account // Peers is a handler that returns peers of the account
type Peers struct { type Peers struct {
accountManager server.AccountManager accountManager server.AccountManager
authAudience string authAudience string
@@ -144,5 +144,8 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
Version: peer.Meta.WtVersion, Version: peer.Meta.WtVersion,
Groups: groupsInfo, Groups: groupsInfo,
SshEnabled: peer.SSHEnabled, SshEnabled: peer.SSHEnabled,
Hostname: peer.Meta.Hostname,
UserId: &peer.UserID,
UiVersion: &peer.Meta.UIVersion,
} }
} }

View File

@@ -19,7 +19,7 @@ import (
func initTestMetaData(peer ...*server.Peer) *Peers { func initTestMetaData(peer ...*server.Peer) *Peers {
return &Peers{ return &Peers{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",

View File

@@ -123,7 +123,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
var req api.PutApiRoutesIdJSONBody var req api.PutApiRoutesIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil { if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return

View File

@@ -120,7 +120,7 @@ func initRoutesTestData() *Routes {
} }
return routeToUpdate, nil return routeToUpdate, nil
}, },
GetAccountWithAuthorizationClaimsFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, error) {
return testingAccount, nil return testingAccount, nil
}, },
}, },

View File

@@ -66,14 +66,14 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
} }
return &rule, nil return &rule, nil
}, },
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{ return &server.Account{
Id: claims.AccountId, Id: claims.AccountId,
Domain: "hotmail.com", Domain: "hotmail.com",
Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}}, Rules: map[string]*server.Rule{"id-existed": &server.Rule{ID: "id-existed"}},
Groups: map[string]*server.Group{ Groups: map[string]*server.Group{
"F": &server.Group{ID: "F"}, "F": {ID: "F"},
"G": &server.Group{ID: "G"}, "G": {ID: "G"},
}, },
}, nil }, nil
}, },

View File

@@ -31,7 +31,7 @@ const (
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys { func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
return &SetupKeys{ return &SetupKeys{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{ return &server.Account{
Id: testAccountID, Id: testAccountID,
Domain: "hotmail.com", Domain: "hotmail.com",

View File

@@ -5,12 +5,11 @@ import (
"fmt" "fmt"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/api"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"net/http" "net/http"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
) )
@@ -82,6 +81,50 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
} }
// CreateUserHandler creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "", http.StatusNotFound)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
}
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown {
http.Error(w, "unknown user role "+req.Role, http.StatusBadRequest)
return
}
newUser, err := h.accountManager.CreateUser(account.Id, &server.UserInfo{
Email: req.Email,
Name: *req.Name,
Role: req.Role,
AutoGroups: req.AutoGroups,
})
if err != nil {
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.UserAlreadyExists:
http.Error(w, "You can't invite users with an existing NetBird account.", http.StatusPreconditionFailed)
return
default:
}
}
http.Error(w, "failed to invite", http.StatusInternalServerError)
return
}
writeJSONObject(w, toUserResponse(newUser))
}
// GetUsers returns a list of users of the account this user belongs to. // GetUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager. // It also gathers additional user data (like email and name) from the IDP manager.
func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
@@ -101,7 +144,7 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
return return
} }
users := []*api.User{} users := make([]*api.User, 0)
for _, r := range data { for _, r := range data {
users = append(users, toUserResponse(r)) users = append(users, toUserResponse(r))
} }
@@ -116,11 +159,22 @@ func toUserResponse(user *server.UserInfo) *api.User {
autoGroups = []string{} autoGroups = []string{}
} }
var userStatus api.UserStatus
switch user.Status {
case "active":
userStatus = api.UserStatusActive
case "invited":
userStatus = api.UserStatusInvited
default:
userStatus = api.UserStatusDisabled
}
return &api.User{ return &api.User{
Id: user.ID, Id: user.ID,
Name: user.Name, Name: user.Name,
Email: user.Email, Email: user.Email,
Role: user.Role, Role: user.Role,
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: userStatus,
} }
} }

View File

@@ -16,7 +16,7 @@ import (
func initUsers(user ...*server.User) *UserHandler { func initUsers(user ...*server.User) *UserHandler {
return &UserHandler{ return &UserHandler{
accountManager: &mock_server.MockAccountManager{ accountManager: &mock_server.MockAccountManager{
GetAccountWithAuthorizationClaimsFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
users := make(map[string]*server.User, 0) users := make(map[string]*server.User, 0)
for _, u := range user { for _, u := range user {
users[u.Id] = u users[u.Id] = u

View File

@@ -6,11 +6,14 @@ import (
"fmt" "fmt"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http" "net/http"
"time" "time"
) )
//writeJSONObject simply writes object to the HTTP reponse in JSON format // writeJSONObject simply writes object to the HTTP reponse in JSON format
func writeJSONObject(w http.ResponseWriter, obj interface{}) { func writeJSONObject(w http.ResponseWriter, obj interface{}) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json; charset=UTF-8") w.Header().Set("Content-Type", "application/json; charset=UTF-8")
@@ -21,7 +24,7 @@ func writeJSONObject(w http.ResponseWriter, obj interface{}) {
} }
} }
//Duration is used strictly for JSON requests/responses due to duration marshalling issues // Duration is used strictly for JSON requests/responses due to duration marshalling issues
type Duration struct { type Duration struct {
time.Duration time.Duration
} }
@@ -57,10 +60,32 @@ func getJWTAccount(accountManager server.AccountManager,
jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience) jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountWithAuthorizationClaims(jwtClaims) account, err := accountManager.GetAccountFromToken(jwtClaims)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
} }
return account, nil return account, nil
} }
func toHTTPError(err error, w http.ResponseWriter) {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.Internal {
http.Error(w, errStatus.String(), http.StatusInternalServerError)
return
}
if ok && errStatus.Code() == codes.NotFound {
http.Error(w, errStatus.String(), http.StatusNotFound)
return
}
if ok && errStatus.Code() == codes.InvalidArgument {
http.Error(w, errStatus.String(), http.StatusBadRequest)
return
}
unhandledMSG := fmt.Sprintf("got unhandled error code, error: %s", errStatus.String())
log.Error(unhandledMSG)
http.Error(w, unhandledMSG, http.StatusInternalServerError)
}

View File

@@ -7,7 +7,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
@@ -54,6 +53,16 @@ type Auth0Credentials struct {
mux sync.Mutex mux sync.Mutex
} }
// createUserRequest is a user create request
type createUserRequest struct {
Email string `json:"email"`
Name string `json:"name"`
AppMeta AppMetadata `json:"app_metadata"`
Connection string `json:"connection"`
Password string `json:"password"`
VerifyEmail bool `json:"verify_email"`
}
// userExportJobRequest is a user export request struct // userExportJobRequest is a user export request struct
type userExportJobRequest struct { type userExportJobRequest struct {
Format string `json:"format"` Format string `json:"format"`
@@ -87,12 +96,13 @@ type userExportJobStatusResponse struct {
// auth0Profile represents an Auth0 user profile response // auth0Profile represents an Auth0 user profile response
type auth0Profile struct { type auth0Profile struct {
AccountID string `json:"wt_account_id"` AccountID string `json:"wt_account_id"`
UserID string `json:"user_id"` PendingInvite bool `json:"wt_pending_invite"`
Name string `json:"name"` UserID string `json:"user_id"`
Email string `json:"email"` Name string `json:"name"`
CreatedAt string `json:"created_at"` Email string `json:"email"`
LastLogin string `json:"last_login"` CreatedAt string `json:"created_at"`
LastLogin string `json:"last_login"`
} }
// NewAuth0Manager creates a new instance of the Auth0Manager // NewAuth0Manager creates a new instance of the Auth0Manager
@@ -172,7 +182,7 @@ func (c *Auth0Credentials) requestJWTToken() (*http.Response, error) {
// parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds // parseRequestJWTResponse parses jwt raw response body and extracts token and expires in seconds
func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) { func (c *Auth0Credentials) parseRequestJWTResponse(rawBody io.ReadCloser) (JWTToken, error) {
jwtToken := JWTToken{} jwtToken := JWTToken{}
body, err := ioutil.ReadAll(rawBody) body, err := io.ReadAll(rawBody)
if err != nil { if err != nil {
return jwtToken, err return jwtToken, err
} }
@@ -230,7 +240,7 @@ func (c *Auth0Credentials) Authenticate() (JWTToken, error) {
return c.jwtToken, nil return c.jwtToken, nil
} }
func batchRequestUsersURL(authIssuer, accountID string, page int) (string, url.Values, error) { func batchRequestUsersURL(authIssuer, accountID string, page int, perPage int) (string, url.Values, error) {
u, err := url.Parse(authIssuer + "/api/v2/users") u, err := url.Parse(authIssuer + "/api/v2/users")
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@@ -238,6 +248,7 @@ func batchRequestUsersURL(authIssuer, accountID string, page int) (string, url.V
q := u.Query() q := u.Query()
q.Set("page", strconv.Itoa(page)) q.Set("page", strconv.Itoa(page))
q.Set("search_engine", "v3") q.Set("search_engine", "v3")
q.Set("per_page", strconv.Itoa(perPage))
q.Set("q", "app_metadata.wt_account_id:"+accountID) q.Set("q", "app_metadata.wt_account_id:"+accountID)
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
@@ -259,8 +270,9 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
// https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations // https://auth0.com/docs/manage-users/user-search/retrieve-users-with-get-users-endpoint#limitations
// auth0 limitation of 1000 users via this endpoint // auth0 limitation of 1000 users via this endpoint
resultsPerPage := 50
for page := 0; page < 20; page++ { for page := 0; page < 20; page++ {
reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page) reqURL, query, err := batchRequestUsersURL(am.authIssuer, accountID, page, resultsPerPage)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -283,30 +295,31 @@ func (am *Auth0Manager) GetAccount(accountID string) ([]*UserData, error) {
return nil, err return nil, err
} }
if res.StatusCode != 200 {
return nil, fmt.Errorf("failed requesting user data from IdP %s", string(body))
}
var batch []UserData var batch []UserData
err = json.Unmarshal(body, &batch) err = json.Unmarshal(body, &batch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("requested batch; %v", batch) log.Debugf("returned user batch for accountID %s on page %d, %v", accountID, page, batch)
err = res.Body.Close() err = res.Body.Close()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if res.StatusCode != 200 {
return nil, fmt.Errorf("unable to request UserData from auth0, statusCode %d", res.StatusCode)
}
if len(batch) == 0 {
return list, nil
}
for user := range batch { for user := range batch {
list = append(list, &batch[user]) list = append(list, &batch[user])
} }
if len(batch) == 0 || len(batch) < resultsPerPage {
log.Debugf("finished loading users for accountID %s", accountID)
return list, nil
}
} }
return list, nil return list, nil
@@ -367,14 +380,12 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
reqURL := am.authIssuer + "/api/v2/users/" + userID reqURL := am.authIssuer + "/api/v2/users/" + userID
data, err := am.helper.Marshal(appMetadata) data, err := am.helper.Marshal(map[string]any{"app_metadata": appMetadata})
if err != nil { if err != nil {
return err return err
} }
payloadString := fmt.Sprintf("{\"app_metadata\": %s}", string(data)) payload := strings.NewReader(string(data))
payload := strings.NewReader(payloadString)
req, err := http.NewRequest("PATCH", reqURL, payload) req, err := http.NewRequest("PATCH", reqURL, payload)
if err != nil { if err != nil {
@@ -383,7 +394,7 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken) req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json") req.Header.Add("content-type", "application/json")
log.Debugf("updating metadata for user %s", userID) log.Debugf("updating IdP metadata for user %s", userID)
res, err := am.httpClient.Do(req) res, err := am.httpClient.Do(req)
if err != nil { if err != nil {
@@ -404,6 +415,27 @@ func (am *Auth0Manager) UpdateUserAppMetadata(userID string, appMetadata AppMeta
return nil return nil
} }
func buildCreateUserRequestPayload(email string, name string, accountID string) (string, error) {
req := &createUserRequest{
Email: email,
Name: name,
AppMeta: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: true,
},
Connection: "Username-Password-Authentication",
Password: GeneratePassword(8, 1, 1, 1),
VerifyEmail: true,
}
str, err := json.Marshal(req)
if err != nil {
return "", err
}
return string(str), nil
}
func buildUserExportRequest() (string, error) { func buildUserExportRequest() (string, error) {
req := &userExportJobRequest{} req := &userExportJobRequest{}
fields := make([]map[string]string, 0) fields := make([]map[string]string, 0)
@@ -417,6 +449,11 @@ func buildUserExportRequest() (string, error) {
"export_as": "wt_account_id", "export_as": "wt_account_id",
}) })
fields = append(fields, map[string]string{
"name": "app_metadata.wt_pending_invite",
"export_as": "wt_pending_invite",
})
req.Format = "json" req.Format = "json"
req.Fields = fields req.Fields = fields
@@ -428,28 +465,39 @@ func buildUserExportRequest() (string, error) {
return string(str), nil return string(str), nil
} }
// GetAllAccounts gets all registered accounts with corresponding user data. func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
// It returns a list of users indexed by accountID.
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
jwtToken, err := am.credentials.Authenticate() jwtToken, err := am.credentials.Authenticate()
if err != nil { if err != nil {
return nil, err return nil, err
} }
reqURL := am.authIssuer + "/api/v2/jobs/users-exports" reqURL := am.authIssuer + endpoint
payload := strings.NewReader(payloadStr)
req, err := http.NewRequest("POST", reqURL, payload)
if err != nil {
return nil, err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
return req, nil
}
// GetAllAccounts gets all registered accounts with corresponding user data.
// It returns a list of users indexed by accountID.
func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
payloadString, err := buildUserExportRequest() payloadString, err := buildUserExportRequest()
if err != nil { if err != nil {
return nil, err return nil, err
} }
payload := strings.NewReader(payloadString)
exportJobReq, err := http.NewRequest("POST", reqURL, payload) exportJobReq, err := am.createPostRequest("/api/v2/jobs/users-exports", payloadString)
if err != nil { if err != nil {
return nil, err return nil, err
} }
exportJobReq.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
exportJobReq.Header.Add("content-type", "application/json")
jobResp, err := am.httpClient.Do(exportJobReq) jobResp, err := am.httpClient.Do(exportJobReq)
if err != nil { if err != nil {
@@ -469,7 +517,7 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
var exportJobResp userExportJobResponse var exportJobResp userExportJobResponse
body, err := ioutil.ReadAll(jobResp.Body) body, err := io.ReadAll(jobResp.Body)
if err != nil { if err != nil {
log.Debugf("Coudln't read export job response; %v", err) log.Debugf("Coudln't read export job response; %v", err)
return nil, err return nil, err
@@ -500,6 +548,82 @@ func (am *Auth0Manager) GetAllAccounts() (map[string][]*UserData, error) {
return nil, fmt.Errorf("failed extracting user profiles from auth0") return nil, fmt.Errorf("failed extracting user profiles from auth0")
} }
// GetUserByEmail searches users with a given email. If no users have been found, this function returns an empty list.
// This function can return multiple users. This is due to the Auth0 internals - there could be multiple users with
// the same email but different connections that are considered as separate accounts (e.g., Google and username/password).
func (am *Auth0Manager) GetUserByEmail(email string) ([]*UserData, error) {
jwtToken, err := am.credentials.Authenticate()
if err != nil {
return nil, err
}
reqURL := am.authIssuer + "/api/v2/users-by-email?email=" + email
body, err := doGetReq(am.httpClient, reqURL, jwtToken.AccessToken)
if err != nil {
return nil, err
}
userResp := []*UserData{}
err = am.helper.Unmarshal(body, &userResp)
if err != nil {
log.Debugf("Coudln't unmarshal export job response; %v", err)
return nil, err
}
return userResp, nil
}
// CreateUser creates a new user in Auth0 Idp and sends an invite
func (am *Auth0Manager) CreateUser(email string, name string, accountID string) (*UserData, error) {
payloadString, err := buildCreateUserRequestPayload(email, name, accountID)
if err != nil {
return nil, err
}
req, err := am.createPostRequest("/api/v2/users", payloadString)
if err != nil {
return nil, err
}
resp, err := am.httpClient.Do(req)
if err != nil {
log.Debugf("Couldn't get job response %v", err)
return nil, err
}
defer func() {
err = resp.Body.Close()
if err != nil {
log.Errorf("error while closing create user response body: %v", err)
}
}()
if !(resp.StatusCode == 200 || resp.StatusCode == 201) {
return nil, fmt.Errorf("unable to create user, statusCode %d", resp.StatusCode)
}
var createResp UserData
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Debugf("Coudln't read export job response; %v", err)
return nil, err
}
err = am.helper.Unmarshal(body, &createResp)
if err != nil {
log.Debugf("Coudln't unmarshal export job response; %v", err)
return nil, err
}
if createResp.ID == "" {
return nil, fmt.Errorf("couldn't create user: response %v", resp)
}
log.Debugf("created user %s in account %s", createResp.ID, accountID)
return &createResp, nil
}
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob. // checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
// If the status is "completed", then return the downloadLink // If the status is "completed", then return the downloadLink
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) { func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
@@ -572,6 +696,10 @@ func (am *Auth0Manager) downloadProfileExport(location string) (map[string][]*Us
ID: profile.UserID, ID: profile.UserID,
Name: profile.Name, Name: profile.Name,
Email: profile.Email, Email: profile.Email,
AppMetadata: AppMetadata{
WTAccountID: profile.AccountID,
WTPendingInvite: profile.PendingInvite,
},
}) })
} }
} }
@@ -605,7 +733,7 @@ func doGetReq(client ManagerHTTPClient, url, accessToken string) ([]byte, error)
return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode) return nil, fmt.Errorf("unable to get %s, statusCode %d", url, res.StatusCode)
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -4,7 +4,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"io/ioutil" "io"
"net/http" "net/http"
"strings" "strings"
"testing" "testing"
@@ -22,13 +22,13 @@ type mockHTTPClient struct {
} }
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
body, err := ioutil.ReadAll(req.Body) body, err := io.ReadAll(req.Body)
if err == nil { if err == nil {
c.reqBody = string(body) c.reqBody = string(body)
} }
return &http.Response{ return &http.Response{
StatusCode: c.code, StatusCode: c.code,
Body: ioutil.NopCloser(strings.NewReader(c.resBody)), Body: io.NopCloser(strings.NewReader(c.resBody)),
}, c.err }, c.err
} }
@@ -130,7 +130,7 @@ func TestAuth0_RequestJWTToken(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
body, err := ioutil.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
assert.NoError(t, err, "unable to read the response body") assert.NoError(t, err, "unable to read the response body")
jwtToken := JWTToken{} jwtToken := JWTToken{}
@@ -178,7 +178,7 @@ func TestAuth0_ParseRequestJWTResponse(t *testing.T) {
for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} { for _, testCase := range []parseRequestJWTResponseTest{parseRequestJWTResponseTestCase1, parseRequestJWTResponseTestCase2} {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
rawBody := ioutil.NopCloser(strings.NewReader(testCase.inputResBody)) rawBody := io.NopCloser(strings.NewReader(testCase.inputResBody))
config := Auth0ClientConfig{} config := Auth0ClientConfig{}
@@ -320,7 +320,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
exp := 15 exp := 15
token := newTestJWT(t, exp) token := newTestJWT(t, exp)
appMetadata := AppMetadata{WTAccountId: "ok"} appMetadata := AppMetadata{WTAccountID: "ok"}
updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{ updateUserAppMetadataTestCase1 := updateUserAppMetadataTest{
name: "Bad Authentication", name: "Bad Authentication",
@@ -340,7 +340,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{ updateUserAppMetadataTestCase2 := updateUserAppMetadataTest{
name: "Bad Status Code", name: "Bad Status Code",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\": {\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountId), expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID),
appMetadata: appMetadata, appMetadata: appMetadata,
statusCode: 400, statusCode: 400,
helper: JsonParser{}, helper: JsonParser{},
@@ -363,7 +363,7 @@ func TestAuth0_UpdateUserAppMetadata(t *testing.T) {
updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{ updateUserAppMetadataTestCase4 := updateUserAppMetadataTest{
name: "Good request", name: "Good request",
inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp), inputReqBody: fmt.Sprintf("{\"access_token\":\"%s\",\"scope\":\"read:users\",\"expires_in\":%d,\"token_type\":\"Bearer\"}", token, exp),
expectedReqBody: fmt.Sprintf("{\"app_metadata\": {\"wt_account_id\":\"%s\"}}", appMetadata.WTAccountId), expectedReqBody: fmt.Sprintf("{\"app_metadata\":{\"wt_account_id\":\"%s\",\"wt_pending_invite\":false}}", appMetadata.WTAccountID),
appMetadata: appMetadata, appMetadata: appMetadata,
statusCode: 200, statusCode: 200,
helper: JsonParser{}, helper: JsonParser{},

View File

@@ -13,6 +13,8 @@ type Manager interface {
GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error) GetUserDataByID(userId string, appMetadata AppMetadata) (*UserData, error)
GetAccount(accountId string) ([]*UserData, error) GetAccount(accountId string) ([]*UserData, error)
GetAllAccounts() (map[string][]*UserData, error) GetAllAccounts() (map[string][]*UserData, error)
CreateUser(email string, name string, accountID string) (*UserData, error)
GetUserByEmail(email string) ([]*UserData, error)
} }
// Config an idp configuration struct to be loaded from management server's config file // Config an idp configuration struct to be loaded from management server's config file
@@ -38,16 +40,18 @@ type ManagerHelper interface {
} }
type UserData struct { type UserData struct {
Email string `json:"email"` Email string `json:"email"`
Name string `json:"name"` Name string `json:"name"`
ID string `json:"user_id"` ID string `json:"user_id"`
AppMetadata AppMetadata `json:"app_metadata"`
} }
// AppMetadata user app metadata to associate with a profile // AppMetadata user app metadata to associate with a profile
type AppMetadata struct { type AppMetadata struct {
// Wiretrustee account id to update in the IDP // WTAccountID is a NetBird (previously Wiretrustee) account id to update in the IDP
// maps to wt_account_id when json.marshal // maps to wt_account_id when json.marshal
WTAccountId string `json:"wt_account_id"` WTAccountID string `json:"wt_account_id,omitempty"`
WTPendingInvite bool `json:"wt_pending_invite"`
} }
// JWTToken a JWT object that holds information of a token // JWTToken a JWT object that holds information of a token

View File

@@ -1,6 +1,18 @@
package idp package idp
import "encoding/json" import (
"encoding/json"
"math/rand"
"strings"
)
var (
lowerCharSet = "abcdedfghijklmnopqrst"
upperCharSet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
specialCharSet = "!@#$%&*"
numberSet = "0123456789"
allCharSet = lowerCharSet + upperCharSet + specialCharSet + numberSet
)
type JsonParser struct{} type JsonParser struct{}
@@ -11,3 +23,37 @@ func (JsonParser) Marshal(v interface{}) ([]byte, error) {
func (JsonParser) Unmarshal(data []byte, v interface{}) error { func (JsonParser) Unmarshal(data []byte, v interface{}) error {
return json.Unmarshal(data, v) return json.Unmarshal(data, v)
} }
// GeneratePassword generates user password
func GeneratePassword(passwordLength, minSpecialChar, minNum, minUpperCase int) string {
var password strings.Builder
//Set special character
for i := 0; i < minSpecialChar; i++ {
random := rand.Intn(len(specialCharSet))
password.WriteString(string(specialCharSet[random]))
}
//Set numeric
for i := 0; i < minNum; i++ {
random := rand.Intn(len(numberSet))
password.WriteString(string(numberSet[random]))
}
//Set uppercase
for i := 0; i < minUpperCase; i++ {
random := rand.Intn(len(upperCharSet))
password.WriteString(string(upperCharSet[random]))
}
remainingLength := passwordLength - minSpecialChar - minNum - minUpperCase
for i := 0; i < remainingLength; i++ {
random := rand.Intn(len(allCharSet))
password.WriteString(string(allCharSet[random]))
}
inRune := []rune(password.String())
rand.Shuffle(len(inRune), func(i, j int) {
inRune[i], inRune[j] = inRune[j], inRune[i]
})
return string(inRune)
}

View File

@@ -2,7 +2,6 @@ package server_test
import ( import (
"context" "context"
"io/ioutil"
"math/rand" "math/rand"
"net" "net"
"os" "os"
@@ -45,7 +44,7 @@ var _ = Describe("Management service", func() {
level, _ := log.ParseLevel("Debug") level, _ := log.ParseLevel("Debug")
log.SetLevel(level) log.SetLevel(level)
var err error var err error
dataDir, err = ioutil.TempDir("", "wiretrustee_mgmt_test_tmp_*") dataDir, err = os.MkdirTemp("", "wiretrustee_mgmt_test_tmp_*")
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json")) err = util.CopyFileContents("testdata/store.json", filepath.Join(dataDir, "store.json"))

View File

@@ -0,0 +1,283 @@
// Package metrics gather anonymous information about the usage of NetBird management
package metrics
import (
"context"
"encoding/json"
"fmt"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/server"
log "github.com/sirupsen/logrus"
"io"
"net/http"
"strings"
"time"
)
const (
// PayloadEvent identifies an event type
PayloadEvent = "self-hosted stats"
// payloadEndpoint metrics endpoint to send anonymous data
payloadEndpoint = "https://metrics.netbird.io"
// defaultPushInterval default interval to push metrics
defaultPushInterval = 24 * time.Hour
// requestTimeout http request timeout
requestTimeout = 30 * time.Second
)
type getTokenResponse struct {
PublicAPIToken string `json:"public_api_token"`
}
type pushPayload struct {
APIKey string `json:"api_key"`
DistinctID string `json:"distinct_id"`
Event string `json:"event"`
Properties properties `json:"properties"`
Timestamp time.Time `json:"timestamp"`
}
// properties metrics to push
type properties map[string]interface{}
// DataSource metric data source
type DataSource interface {
GetAllAccounts() []*server.Account
}
// ConnManager peer connection manager that holds state for current active connections
type ConnManager interface {
GetAllConnectedPeers() map[string]struct{}
}
// Worker metrics collector and pusher
type Worker struct {
ctx context.Context
id string
dataSource DataSource
connManager ConnManager
startupTime time.Time
lastRun time.Time
}
// NewWorker returns a metrics worker
func NewWorker(ctx context.Context, id string, dataSource DataSource, connManager ConnManager) *Worker {
currentTime := time.Now()
return &Worker{
ctx: ctx,
id: id,
dataSource: dataSource,
connManager: connManager,
startupTime: currentTime,
lastRun: currentTime,
}
}
// Run runs the metrics worker
func (w *Worker) Run() {
pushTicker := time.NewTicker(defaultPushInterval)
for {
select {
case <-w.ctx.Done():
return
case <-pushTicker.C:
err := w.sendMetrics()
if err != nil {
log.Error(err)
}
}
}
}
func (w *Worker) sendMetrics() error {
ctx, cancel := context.WithTimeout(w.ctx, requestTimeout)
defer cancel()
apiKey, err := getAPIKey(ctx)
if err != nil {
return err
}
payload := w.generatePayload(apiKey)
payloadString, err := buildMetricsPayload(payload)
if err != nil {
return err
}
httpClient := http.Client{}
exportJobReq, err := createPostRequest(ctx, payloadEndpoint+"/capture/", payloadString)
if err != nil {
return fmt.Errorf("unable to create metrics post request %v", err)
}
jobResp, err := httpClient.Do(exportJobReq)
if err != nil {
return fmt.Errorf("unable to push metrics %v", err)
}
defer func() {
err = jobResp.Body.Close()
if err != nil {
log.Errorf("error while closing update metrics response body: %v", err)
}
}()
if jobResp.StatusCode != 200 {
return fmt.Errorf("unable to push anonymous metrics, got statusCode %d", jobResp.StatusCode)
}
log.Infof("sent anonymous metrics, next push will happen in %s. "+
"You can disable these metrics by running with flag --disable-anonymous-metrics,"+
" see more information at https://netbird.io/docs/FAQ/metrics-collection", defaultPushInterval)
return nil
}
func (w *Worker) generatePayload(apiKey string) pushPayload {
properties := w.generateProperties()
return pushPayload{
APIKey: apiKey,
DistinctID: w.id,
Event: PayloadEvent,
Properties: properties,
Timestamp: time.Now(),
}
}
func (w *Worker) generateProperties() properties {
var (
uptime float64
accounts int
users int
peers int
setupKeysUsage int
activePeersLastDay int
osPeers map[string]int
userPeers int
rules int
groups int
routes int
nameservers int
version string
)
start := time.Now()
metricsProperties := make(properties)
osPeers = make(map[string]int)
uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers()
version = system.NetbirdVersion()
for _, account := range w.dataSource.GetAllAccounts() {
accounts++
users = users + len(account.Users)
rules = rules + len(account.Rules)
groups = groups + len(account.Groups)
routes = routes + len(account.Routes)
nameservers = nameservers + len(account.NameServerGroups)
for _, key := range account.SetupKeys {
setupKeysUsage = setupKeysUsage + key.UsedTimes
}
for _, peer := range account.Peers {
peers++
if peer.SetupKey != "" {
userPeers++
}
_, connected := connections[peer.Key]
if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++
}
osKey := strings.ToLower(fmt.Sprintf("peer_os_%s", peer.Meta.GoOS))
osCount := osPeers[osKey]
osPeers[osKey] = osCount + 1
}
}
metricsProperties["uptime"] = uptime
metricsProperties["accounts"] = accounts
metricsProperties["users"] = users
metricsProperties["peers"] = peers
metricsProperties["setup_keys_usage"] = setupKeysUsage
metricsProperties["active_peers_last_day"] = activePeersLastDay
metricsProperties["user_peers"] = userPeers
metricsProperties["rules"] = rules
metricsProperties["groups"] = groups
metricsProperties["routes"] = routes
metricsProperties["nameservers"] = nameservers
metricsProperties["version"] = version
for os, count := range osPeers {
metricsProperties[os] = count
}
metricsProperties["metric_generation_time"] = time.Since(start).Milliseconds()
return metricsProperties
}
func getAPIKey(ctx context.Context) (string, error) {
httpClient := http.Client{}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, payloadEndpoint+"/GetToken", nil)
if err != nil {
return "", fmt.Errorf("unable to create request for metrics public api token %v", err)
}
response, err := httpClient.Do(req)
if err != nil {
return "", fmt.Errorf("unable to request metrics public api token %v", err)
}
defer func() {
err = response.Body.Close()
if err != nil {
log.Errorf("error while closing metrics token response body: %v", err)
}
}()
if response.StatusCode != 200 {
return "", fmt.Errorf("unable to retrieve metrics token, statusCode %d", response.StatusCode)
}
body, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("coudln't get metrics token response; %v", err)
}
var tokenResponse getTokenResponse
err = json.Unmarshal(body, &tokenResponse)
if err != nil {
return "", fmt.Errorf("coudln't parse metrics public api token; %v", err)
}
return tokenResponse.PublicAPIToken, nil
}
func buildMetricsPayload(payload pushPayload) (string, error) {
str, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("unable to marshal metrics payload, got err: %v", err)
}
return string(str), nil
}
func createPostRequest(ctx context.Context, endpoint string, payloadStr string) (*http.Request, error) {
reqURL := endpoint
payload := strings.NewReader(payloadStr)
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, payload)
if err != nil {
return nil, err
}
req.Header.Add("content-type", "application/json")
return req, nil
}

View File

@@ -1,6 +1,7 @@
package mock_server package mock_server
import ( import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -10,49 +11,56 @@ import (
) )
type MockAccountManager struct { type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error)
GetAccountByUserFunc func(userId string) (*server.Account, error) GetAccountByUserFunc func(userId string) (*server.Account, error)
CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error)
GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error)
GetAccountByIdFunc func(accountId string) (*server.Account, error) GetAccountByIdFunc func(accountId string) (*server.Account, error)
GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error)
GetAccountWithAuthorizationClaimsFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExistsFunc func(accountId string) (*bool, error)
AccountExistsFunc func(accountId string) (*bool, error) GetPeerFunc func(peerKey string) (*server.Peer, error)
GetPeerFunc func(peerKey string) (*server.Peer, error) MarkPeerConnectedFunc func(peerKey string, connected bool) error
MarkPeerConnectedFunc func(peerKey string, connected bool) error RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error)
RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error) DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error)
DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error)
GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error)
GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error)
GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error)
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error) GetGroupFunc func(accountID, groupID string) (*server.Group, error)
GetGroupFunc func(accountID, groupID string) (*server.Group, error) SaveGroupFunc func(accountID string, group *server.Group) error
SaveGroupFunc func(accountID string, group *server.Group) error UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) DeleteGroupFunc func(accountID, groupID string) error
DeleteGroupFunc func(accountID, groupID string) error ListGroupsFunc func(accountID string) ([]*server.Group, error)
ListGroupsFunc func(accountID string) ([]*server.Group, error) GroupAddPeerFunc func(accountID, groupID, peerKey string) error
GroupAddPeerFunc func(accountID, groupID, peerKey string) error GroupDeletePeerFunc func(accountID, groupID, peerKey string) error
GroupDeletePeerFunc func(accountID, groupID, peerKey string) error GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error)
GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) GetRuleFunc func(accountID, ruleID string) (*server.Rule, error)
GetRuleFunc func(accountID, ruleID string) (*server.Rule, error) SaveRuleFunc func(accountID string, rule *server.Rule) error
SaveRuleFunc func(accountID string, rule *server.Rule) error UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error)
UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error) DeleteRuleFunc func(accountID, ruleID string) error
DeleteRuleFunc func(accountID, ruleID string) error ListRulesFunc func(accountID string) ([]*server.Rule, error)
ListRulesFunc func(accountID string) ([]*server.Rule, error) GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error) CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error)
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) GetRouteFunc func(accountID, routeID string) (*route.Route, error)
GetRouteFunc func(accountID, routeID string) (*route.Route, error) SaveRouteFunc func(accountID string, route *route.Route) error
SaveRouteFunc func(accountID string, route *route.Route) error UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) DeleteRouteFunc func(accountID, routeID string) error
DeleteRouteFunc func(accountID, routeID string) error ListRoutesFunc func(accountID string) ([]*route.Route, error)
ListRoutesFunc func(accountID string) ([]*route.Route, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error)
SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error) ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error)
ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error)
SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error)
SaveNameServerGroupFunc func(accountID string, nsGroupToSave *nbdns.NameServerGroup) error
UpdateNameServerGroupFunc func(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
DeleteNameServerGroupFunc func(accountID, nsGroupID string) error
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(accountID string, key *server.UserInfo) (*server.UserInfo, error)
GetAccountFromTokenFunc func(claims jwtclaims.AuthorizationClaims) (*server.Account, error)
} }
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@@ -119,19 +127,6 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
) )
} }
// GetAccountWithAuthorizationClaims mock implementation of GetAccountWithAuthorizationClaims from server.AccountManager interface
func (am *MockAccountManager) GetAccountWithAuthorizationClaims(
claims jwtclaims.AuthorizationClaims,
) (*server.Account, error) {
if am.GetAccountWithAuthorizationClaimsFunc != nil {
return am.GetAccountWithAuthorizationClaimsFunc(claims)
}
return nil, status.Errorf(
codes.Unimplemented,
"method GetAccountWithAuthorizationClaims is not implemented",
)
}
// AccountExists mock implementation of AccountExists from server.AccountManager interface // AccountExists mock implementation of AccountExists from server.AccountManager interface
func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) { func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
if am.AccountExistsFunc != nil { if am.AccountExistsFunc != nil {
@@ -430,3 +425,67 @@ func (am *MockAccountManager) SaveUser(accountID string, user *server.User) (*se
} }
return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented") return nil, status.Errorf(codes.Unimplemented, "method SaveUser is not implemented")
} }
// GetNameServerGroup mocks GetNameServerGroup of the AccountManager interface
func (am *MockAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
if am.GetNameServerGroupFunc != nil {
return am.GetNameServerGroupFunc(accountID, nsGroupID)
}
return nil, nil
}
// CreateNameServerGroup mocks CreateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
if am.CreateNameServerGroupFunc != nil {
return am.CreateNameServerGroupFunc(accountID, name, description, nameServerList, groups, enabled)
}
return nil, nil
}
// SaveNameServerGroup mocks SaveNameServerGroup of the AccountManager interface
func (am *MockAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
if am.SaveNameServerGroupFunc != nil {
return am.SaveNameServerGroupFunc(accountID, nsGroupToSave)
}
return nil
}
// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface
func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
if am.UpdateNameServerGroupFunc != nil {
return am.UpdateNameServerGroupFunc(accountID, nsGroupID, operations)
}
return nil, nil
}
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
if am.DeleteNameServerGroupFunc != nil {
return am.DeleteNameServerGroupFunc(accountID, nsGroupID)
}
return nil
}
// ListNameServerGroups mocks ListNameServerGroups of the AccountManager interface
func (am *MockAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
if am.ListNameServerGroupsFunc != nil {
return am.ListNameServerGroupsFunc(accountID)
}
return nil, nil
}
// CreateUser mocks CreateUser of the AccountManager interface
func (am *MockAccountManager) CreateUser(accountID string, invite *server.UserInfo) (*server.UserInfo, error) {
if am.CreateUserFunc != nil {
return am.CreateUserFunc(accountID, invite)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
if am.GetAccountFromTokenFunc != nil {
return am.GetAccountFromTokenFunc(claims)
}
return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
}

View File

@@ -0,0 +1,333 @@
package server
import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/rs/xid"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"strconv"
"unicode/utf8"
)
const (
// UpdateNameServerGroupName indicates a nameserver group name update operation
UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota
// UpdateNameServerGroupDescription indicates a nameserver group description update operation
UpdateNameServerGroupDescription
// UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation
UpdateNameServerGroupNameServers
// UpdateNameServerGroupGroups indicates a nameserver group' groups update operation
UpdateNameServerGroupGroups
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation
UpdateNameServerGroupEnabled
)
// NameServerGroupUpdateOperationType operation type
type NameServerGroupUpdateOperationType int
func (t NameServerGroupUpdateOperationType) String() string {
switch t {
case UpdateNameServerGroupDescription:
return "UpdateNameServerGroupDescription"
case UpdateNameServerGroupName:
return "UpdateNameServerGroupName"
case UpdateNameServerGroupNameServers:
return "UpdateNameServerGroupNameServers"
case UpdateNameServerGroupGroups:
return "UpdateNameServerGroupGroups"
case UpdateNameServerGroupEnabled:
return "UpdateNameServerGroupEnabled"
default:
return "InvalidOperation"
}
}
// NameServerGroupUpdateOperation operation object with type and values to be applied
type NameServerGroupUpdateOperation struct {
Type NameServerGroupUpdateOperationType
Values []string
}
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
nsGroup, found := account.NameServerGroups[nsGroupID]
if found {
return nsGroup.Copy(), nil
}
return nil, status.Errorf(codes.NotFound, "nameserver group with ID %s not found", nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, enabled bool) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
newNSGroup := &nbdns.NameServerGroup{
ID: xid.New().String(),
Name: name,
Description: description,
NameServers: nameServerList,
Groups: groups,
Enabled: enabled,
}
err = validateNameServerGroup(false, newNSGroup, account)
if err != nil {
return nil, err
}
if account.NameServerGroups == nil {
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup)
}
account.NameServerGroups[newNSGroup.ID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
return newNSGroup.Copy(), nil
}
// SaveNameServerGroup saves nameserver group
func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error {
am.mux.Lock()
defer am.mux.Unlock()
if nsGroupToSave == nil {
return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
err = validateNameServerGroup(true, nsGroupToSave, account)
if err != nil {
return err
}
account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
return nil
}
// UpdateNameServerGroup updates existing nameserver group with set of operations
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
if len(operations) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "operations shouldn't be empty")
}
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
if !ok {
return nil, status.Errorf(codes.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
}
newNSGroup := nsGroupToUpdate.Copy()
for _, operation := range operations {
valuesCount := len(operation.Values)
if valuesCount < 1 {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
}
for _, value := range operation.Values {
if value == "" {
return nil, status.Errorf(codes.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
}
}
switch operation.Type {
case UpdateNameServerGroupDescription:
newNSGroup.Description = operation.Values[0]
case UpdateNameServerGroupName:
if valuesCount > 1 {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
}
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
if err != nil {
return nil, err
}
newNSGroup.Name = operation.Values[0]
case UpdateNameServerGroupNameServers:
var nsList []nbdns.NameServer
for _, url := range operation.Values {
ns, err := nbdns.ParseNameServerURL(url)
if err != nil {
return nil, err
}
nsList = append(nsList, ns)
}
err = validateNSList(nsList)
if err != nil {
return nil, err
}
newNSGroup.NameServers = nsList
case UpdateNameServerGroupGroups:
err = validateGroups(operation.Values, account.Groups)
if err != nil {
return nil, err
}
newNSGroup.Groups = operation.Values
case UpdateNameServerGroupEnabled:
enabled, err := strconv.ParseBool(operation.Values[0])
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
}
newNSGroup.Enabled = enabled
}
}
account.NameServerGroups[nsGroupID] = newNSGroup
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
return newNSGroup.Copy(), nil
}
// DeleteNameServerGroup deletes nameserver group with nsGroupID
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return status.Errorf(codes.NotFound, "account not found")
}
delete(account.NameServerGroups, nsGroupID)
account.Network.IncSerial()
err = am.Store.SaveAccount(account)
if err != nil {
return err
}
return nil
}
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
for _, item := range account.NameServerGroups {
nsGroups = append(nsGroups, item.Copy())
}
return nsGroups, nil
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
nsGroupID := ""
if existingGroup {
nsGroupID = nameserverGroup.ID
_, found := account.NameServerGroups[nsGroupID]
if !found {
return status.Errorf(codes.NotFound, "nameserver group with ID %s was not found", nsGroupID)
}
}
err := validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups)
if err != nil {
return err
}
err = validateNSList(nameserverGroup.NameServers)
if err != nil {
return err
}
err = validateGroups(nameserverGroup.Groups, account.Groups)
if err != nil {
return err
}
return nil
}
func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error {
if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" {
return status.Errorf(codes.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar)
}
for _, nsGroup := range nsGroupMap {
if name == nsGroup.Name && nsGroup.ID != nsGroupID {
return status.Errorf(codes.InvalidArgument, "a nameserver group with name %s already exist", name)
}
}
return nil
}
func validateNSList(list []nbdns.NameServer) error {
nsListLenght := len(list)
if nsListLenght == 0 || nsListLenght > 2 {
return status.Errorf(codes.InvalidArgument, "the list of nameservers should be 1 or 2, got %d", len(list))
}
return nil
}
func validateGroups(list []string, groups map[string]*Group) error {
if len(list) == 0 {
return status.Errorf(codes.InvalidArgument, "the list of group IDs should not be empty")
}
for _, id := range list {
if id == "" {
return status.Errorf(codes.InvalidArgument, "group ID should not be empty string")
}
found := false
for groupID := range groups {
if id == groupID {
found = true
break
}
}
if !found {
return status.Errorf(codes.InvalidArgument, "group id %s not found", id)
}
}
return nil
}

View File

@@ -0,0 +1,965 @@
package server
import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/stretchr/testify/require"
"net/netip"
"testing"
)
const (
group1ID = "group1"
group2ID = "group2"
existingNSGroupName = "existing"
existingNSGroupID = "existingNSGroup"
nsGroupPeer1Key = "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8="
nsGroupPeer2Key = "/yF0+vCfv+mRR5k0dca0TrGdO/oiNeAI58gToZm5NyI="
)
func TestCreateNameServerGroup(t *testing.T) {
type input struct {
name string
description string
enabled bool
groups []string
nameServers []nbdns.NameServer
}
testCases := []struct {
name string
inputArgs input
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedNSGroup *nbdns.NameServerGroup
}{
{
name: "Create A NS Group",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
Name: "super",
Description: "super",
Groups: []string{group1ID},
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Enabled: true,
},
},
{
name: "Should Not Create If Name Exist",
inputArgs: input{
name: existingNSGroupName,
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If Name Is Small",
inputArgs: input{
name: "",
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If Name Is Large",
inputArgs: input{
name: "1234567890123456789012345678901234567890extra",
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Create A NS Group With No Nameservers Should Fail",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Create A NS Group With More Than 2 Nameservers Should Fail",
inputArgs: input{
name: "super",
description: "super",
groups: []string{group1ID},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.3.3"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If Groups Is Empty",
inputArgs: input{
name: "super",
description: "super",
groups: []string{},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If Group Doesn't Exist",
inputArgs: input{
name: "super",
description: "super",
groups: []string{"missingGroup"},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Create If Group ID Is Invalid",
inputArgs: input{
name: "super",
description: "super",
groups: []string{""},
nameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
enabled: true,
},
errFunc: require.Error,
shouldCreate: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
outNSGroup, err := am.CreateNameServerGroup(
account.Id,
testCase.inputArgs.name,
testCase.inputArgs.description,
testCase.inputArgs.nameServers,
testCase.inputArgs.groups,
testCase.inputArgs.enabled,
)
testCase.errFunc(t, err)
if !testCase.shouldCreate {
return
}
// assign generated ID
testCase.expectedNSGroup.ID = outNSGroup.ID
if !testCase.expectedNSGroup.IsEqual(outNSGroup) {
t.Errorf("new nameserver group didn't match expected ns group:\nGot %#v\nExpected:%#v\n", outNSGroup, testCase.expectedNSGroup)
}
})
}
}
func TestSaveNameServerGroup(t *testing.T) {
existingNSGroup := &nbdns.NameServerGroup{
ID: "testingNSGroup",
Name: "super",
Description: "super",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
}
validGroups := []string{group2ID}
invalidGroups := []string{"nonExisting"}
validNameServerList := []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
}
invalidNameServerListLarge := []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.3.3"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
}
invalidID := "doesntExist"
validName := "12345678901234567890qw"
invalidNameLarge := "12345678901234567890qwertyuiopqwertyuiop1"
invalidNameSmall := ""
invalidNameExisting := existingNSGroupName
testCases := []struct {
name string
existingNSGroup *nbdns.NameServerGroup
newID *string
newName *string
newNSList []nbdns.NameServer
newGroups []string
skipCopying bool
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedNSGroup *nbdns.NameServerGroup
}{
{
name: "Should Update Name Server Group",
existingNSGroup: existingNSGroup,
newName: &validName,
newGroups: validGroups,
newNSList: validNameServerList,
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
ID: "testingNSGroup",
Name: validName,
Description: "super",
NameServers: validNameServerList,
Groups: validGroups,
Enabled: true,
},
},
{
name: "Should Not Update If Name Is Small",
existingNSGroup: existingNSGroup,
newName: &invalidNameSmall,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Name Is Large",
existingNSGroup: existingNSGroup,
newName: &invalidNameLarge,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Name Exists",
existingNSGroup: existingNSGroup,
newName: &invalidNameExisting,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If ID Don't Exist",
existingNSGroup: existingNSGroup,
newID: &invalidID,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Nameserver List Is Small",
existingNSGroup: existingNSGroup,
newNSList: []nbdns.NameServer{},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Nameserver List Is Large",
existingNSGroup: existingNSGroup,
newNSList: invalidNameServerListLarge,
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Groups List Is Empty",
existingNSGroup: existingNSGroup,
newGroups: []string{},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Groups List Has Empty ID",
existingNSGroup: existingNSGroup,
newGroups: []string{""},
errFunc: require.Error,
shouldCreate: false,
},
{
name: "Should Not Update If Groups List Has Non Existing Group ID",
existingNSGroup: existingNSGroup,
newGroups: invalidGroups,
errFunc: require.Error,
shouldCreate: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
err = am.Store.SaveAccount(account)
if err != nil {
t.Error("account should be saved")
}
var nsGroupToSave *nbdns.NameServerGroup
if !testCase.skipCopying {
nsGroupToSave = testCase.existingNSGroup.Copy()
if testCase.newID != nil {
nsGroupToSave.ID = *testCase.newID
}
if testCase.newName != nil {
nsGroupToSave.Name = *testCase.newName
}
if testCase.newGroups != nil {
nsGroupToSave.Groups = testCase.newGroups
}
if testCase.newNSList != nil {
nsGroupToSave.NameServers = testCase.newNSList
}
}
err = am.SaveNameServerGroup(account.Id, nsGroupToSave)
testCase.errFunc(t, err)
if !testCase.shouldCreate {
return
}
savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID]
require.True(t, saved)
if !testCase.expectedNSGroup.IsEqual(savedNSGroup) {
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup)
}
})
}
}
func TestUpdateNameServerGroup(t *testing.T) {
nsGroupID := "testingNSGroup"
existingNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "super",
Description: "super",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
}
testCases := []struct {
name string
existingNSGroup *nbdns.NameServerGroup
nsGroupID string
operations []NameServerGroupUpdateOperation
shouldCreate bool
errFunc require.ErrorAssertionFunc
expectedNSGroup *nbdns.NameServerGroup
}{
{
name: "Should Update Single Property",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"superNew"},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "superNew",
Description: "super",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
},
},
{
name: "Should Update Multiple Properties",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"superNew"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupDescription,
Values: []string{"superDescription"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53"},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupGroups,
Values: []string{group1ID, group2ID},
},
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupEnabled,
Values: []string{"false"},
},
},
errFunc: require.NoError,
shouldCreate: true,
expectedNSGroup: &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "superNew",
Description: "superDescription",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("127.0.0.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID, group2ID},
Enabled: false,
},
},
{
name: "Should Not Update On Invalid ID",
existingNSGroup: existingNSGroup,
nsGroupID: "nonExistingNSGroup",
errFunc: require.Error,
},
{
name: "Should Not Update On Empty Operations",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{},
errFunc: require.Error,
},
{
name: "Should Not Update On Empty Values",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Empty String",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{""},
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid Name Large String",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid On Existing Name",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{existingNSGroupName},
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid On Multiple Name Values",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupName,
Values: []string{"nameOne", "nameTwo"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid Boolean",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupEnabled,
Values: []string{"yes"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid Nameservers Wrong Schema",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"https://127.0.0.1:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid Nameservers Wrong IP",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://8.8.8.300:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Large Number Of Nameservers",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupNameServers,
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53", "udp://8.8.4.4:53"},
},
},
errFunc: require.Error,
},
{
name: "Should Not Update On Invalid GroupID",
existingNSGroup: existingNSGroup,
nsGroupID: existingNSGroup.ID,
operations: []NameServerGroupUpdateOperation{
NameServerGroupUpdateOperation{
Type: UpdateNameServerGroupGroups,
Values: []string{"nonExistingGroupID"},
},
},
errFunc: require.Error,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
err = am.Store.SaveAccount(account)
if err != nil {
t.Error("account should be saved")
}
updatedRoute, err := am.UpdateNameServerGroup(account.Id, testCase.nsGroupID, testCase.operations)
testCase.errFunc(t, err)
if !testCase.shouldCreate {
return
}
testCase.expectedNSGroup.ID = updatedRoute.ID
if !testCase.expectedNSGroup.IsEqual(updatedRoute) {
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedNSGroup)
}
})
}
}
func TestDeleteNameServerGroup(t *testing.T) {
nsGroupID := "testingNSGroup"
testingNSGroup := &nbdns.NameServerGroup{
ID: nsGroupID,
Name: "super",
Description: "super",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("1.1.2.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
}
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
account.NameServerGroups[testingNSGroup.ID] = testingNSGroup
err = am.Store.SaveAccount(account)
if err != nil {
t.Error("failed to save account")
}
err = am.DeleteNameServerGroup(account.Id, testingNSGroup.ID)
if err != nil {
t.Error("deleting nameserver group failed with error: ", err)
}
savedAccount, err := am.Store.GetAccount(account.Id)
if err != nil {
t.Error("failed to retrieve saved account with error: ", err)
}
_, found := savedAccount.NameServerGroups[testingNSGroup.ID]
if found {
t.Error("nameserver group shouldn't be found after delete")
}
}
func TestGetNameServerGroup(t *testing.T) {
am, err := createNSManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestNSAccount(t, am)
if err != nil {
t.Error("failed to init testing account")
}
foundGroup, err := am.GetNameServerGroup(account.Id, existingNSGroupID)
if err != nil {
t.Error("getting existing nameserver group failed with error: ", err)
}
if foundGroup == nil {
t.Error("got a nil group while getting nameserver group with ID")
}
_, err = am.GetNameServerGroup(account.Id, "not existing")
if err == nil {
t.Error("getting not existing nameserver group should return error, got nil")
}
}
func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
store, err := createNSStore(t)
if err != nil {
return nil, err
}
return BuildManager(store, NewPeersUpdateManager(), nil)
}
func createNSStore(t *testing.T) (Store, error) {
dataDir := t.TempDir()
store, err := NewStore(dataDir)
if err != nil {
return nil, err
}
return store, nil
}
func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) {
peer1 := &Peer{
Key: nsGroupPeer1Key,
Name: "test-host1@netbird.io",
Meta: PeerSystemMeta{
Hostname: "test-host1@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
}
peer2 := &Peer{
Key: nsGroupPeer2Key,
Name: "test-host2@netbird.io",
Meta: PeerSystemMeta{
Hostname: "test-host2@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
Platform: "x86_64",
OS: "Ubuntu",
WtVersion: "development",
UIVersion: "development",
},
}
existingNSGroup := nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: existingNSGroupName,
Description: "",
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
},
Groups: []string{group1ID},
Enabled: true,
}
accountID := "testingAcc"
userID := "testingUser"
domain := "example.com"
account := newAccountWithId(accountID, userID, domain)
account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup
defaultGroup, err := account.GetGroupAll()
if err != nil {
return nil, err
}
newGroup1 := defaultGroup.Copy()
newGroup1.ID = group1ID
newGroup2 := defaultGroup.Copy()
newGroup2.ID = group2ID
account.Groups[newGroup1.ID] = newGroup1
account.Groups[newGroup2.ID] = newGroup2
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
_, err = am.AddPeer("", userID, peer1)
if err != nil {
return nil, err
}
_, err = am.AddPeer("", userID, peer2)
if err != nil {
return nil, err
}
return account, nil
}

View File

@@ -21,4 +21,6 @@ type Store interface {
SaveAccount(account *Account) error SaveAccount(account *Account) error
GetPeerRoutes(peerKey string) ([]*route.Route, error) GetPeerRoutes(peerKey string) ([]*route.Route, error)
GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error) GetRoutesByPrefix(accountID string, prefix netip.Prefix) ([]*route.Route, error)
GetInstallationID() string
SaveInstallationID(id string) error
} }

View File

@@ -70,3 +70,14 @@ func (p *PeersUpdateManager) CloseChannel(peerKey string) {
log.Debugf("closed updates channel of a peer %s", peerKey) log.Debugf("closed updates channel of a peer %s", peerKey)
} }
// GetAllConnectedPeers returns a copy of the connected peers map
func (p *PeersUpdateManager) GetAllConnectedPeers() map[string]struct{} {
p.channelsMux.Lock()
defer p.channelsMux.Unlock()
m := make(map[string]struct{})
for key := range p.peerChannels {
m[key] = struct{}{}
}
return m
}

View File

@@ -14,6 +14,10 @@ const (
UserRoleAdmin UserRole = "admin" UserRoleAdmin UserRole = "admin"
UserRoleUser UserRole = "user" UserRoleUser UserRole = "user"
UserRoleUnknown UserRole = "unknown" UserRoleUnknown UserRole = "unknown"
UserStatusActive UserStatus = "active"
UserStatusDisabled UserStatus = "disabled"
UserStatusInvited UserStatus = "invited"
) )
// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown // StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown
@@ -28,7 +32,10 @@ func StrRoleToUserRole(strRole string) UserRole {
} }
} }
// UserRole is the role of the User // UserStatus is the status of a User
type UserStatus string
// UserRole is the role of a User
type UserRole string type UserRole string
// User represents a user of the system // User represents a user of the system
@@ -53,24 +60,31 @@ func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) {
Name: "", Name: "",
Role: string(u.Role), Role: string(u.Role),
AutoGroups: u.AutoGroups, AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
}, nil }, nil
} }
if userData.ID != u.Id { if userData.ID != u.Id {
return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id)
} }
userStatus := UserStatusActive
if userData.AppMetadata.WTPendingInvite {
userStatus = UserStatusInvited
}
return &UserInfo{ return &UserInfo{
ID: u.Id, ID: u.Id,
Email: userData.Email, Email: userData.Email,
Name: userData.Name, Name: userData.Name,
Role: string(u.Role), Role: string(u.Role),
AutoGroups: autoGroups, AutoGroups: autoGroups,
Status: string(userStatus),
}, nil }, nil
} }
// Copy the user // Copy the user
func (u *User) Copy() *User { func (u *User) Copy() *User {
autoGroups := []string{} autoGroups := make([]string, 0)
autoGroups = append(autoGroups, u.AutoGroups...) autoGroups = append(autoGroups, u.AutoGroups...)
return &User{ return &User{
Id: u.Id, Id: u.Id,
@@ -98,6 +112,70 @@ func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin) return NewUser(id, UserRoleAdmin)
} }
// CreateUser creates a new user under the given account. Effectively this is a user invite.
func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) {
am.mux.Lock()
defer am.mux.Unlock()
if am.idpManager == nil {
return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites")
}
if invite == nil {
return nil, fmt.Errorf("provided user update is nil")
}
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, Errorf(AccountNotFound, "account %s doesn't exist", accountID)
}
// check if the user is already registered with this email => reject
user, err := am.lookupUserInCacheByEmail(invite.Email, accountID)
if err != nil {
return nil, err
}
if user != nil {
return nil, Errorf(UserAlreadyExists, "user has an existing account")
}
users, err := am.idpManager.GetUserByEmail(invite.Email)
if err != nil {
return nil, err
}
if len(users) > 0 {
return nil, Errorf(UserAlreadyExists, "user has an existing account")
}
idpUser, err := am.idpManager.CreateUser(invite.Email, invite.Name, accountID)
if err != nil {
return nil, err
}
role := StrRoleToUserRole(invite.Role)
newUser := &User{
Id: idpUser.ID,
Role: role,
AutoGroups: invite.AutoGroups,
}
account.Users[idpUser.ID] = newUser
err = am.Store.SaveAccount(account)
if err != nil {
return nil, err
}
_, err = am.refreshCache(account.Id)
if err != nil {
return nil, err
}
return newUser.toUserInfo(idpUser)
}
// SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error.
// Only User.AutoGroups field is allowed to be updated for now. // Only User.AutoGroups field is allowed to be updated for now.
func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) { func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) {
@@ -138,10 +216,13 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User
} }
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
userData, err := am.lookupUserInCache(newUser, accountID) userData, err := am.lookupUserInCache(newUser.Id, account)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if userData == nil {
return nil, status.Errorf(codes.NotFound, "user %s not found in the IdP", newUser.Id)
}
return newUser.toUserInfo(userData) return newUser.toUserInfo(userData)
} }
return newUser.toUserInfo(nil) return newUser.toUserInfo(nil)
@@ -194,7 +275,7 @@ func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, erro
// IsUserAdmin flag for current user authenticated by JWT token // IsUserAdmin flag for current user authenticated by JWT token
func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) {
account, err := am.GetAccountWithAuthorizationClaims(claims) account, err := am.GetAccountFromToken(claims)
if err != nil { if err != nil {
return false, fmt.Errorf("get account: %v", err) return false, fmt.Errorf("get account: %v", err)
} }
@@ -216,7 +297,11 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI
queriedUsers := make([]*idp.UserData, 0) queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
queriedUsers, err = am.lookupCache(account.Users, accountID) users := make(map[string]struct{}, len(account.Users))
for _, user := range account.Users {
users[user.Id] = struct{}{}
}
queriedUsers, err = am.lookupCache(users, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -87,7 +87,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
}, nil }, nil
} }
//defaultBackoff is a basic backoff mechanism for general issues // defaultBackoff is a basic backoff mechanism for general issues
func defaultBackoff(ctx context.Context) backoff.BackOff { func defaultBackoff(ctx context.Context) backoff.BackOff {
return backoff.WithContext(&backoff.ExponentialBackOff{ return backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
@@ -121,9 +121,11 @@ func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
return fmt.Errorf("connection to signal is not ready and in %s state", connState) return fmt.Errorf("connection to signal is not ready and in %s state", connState)
} }
// connect to Signal stream identifying ourselves with a public Wireguard key // connect to Signal stream identifying ourselves with a public WireGuard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management) // todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
stream, err := c.connect(c.key.PublicKey().String()) ctx, cancelStream := context.WithCancel(c.ctx)
defer cancelStream()
stream, err := c.connect(ctx, c.key.PublicKey().String())
if err != nil { if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
return err return err
@@ -180,15 +182,13 @@ func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
return c.connectedCh return c.connectedCh
} }
func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil c.stream = nil
// add key fingerprint to the request header to be identified on the server side // add key fingerprint to the request header to be identified on the server side
md := metadata.New(map[string]string{proto.HeaderId: key}) md := metadata.New(map[string]string{proto.HeaderId: key})
ctx := metadata.NewOutgoingContext(c.ctx, md) metaCtx := metadata.NewOutgoingContext(ctx, md)
stream, err := c.realClient.ConnectStream(metaCtx, grpc.WaitForReady(true))
stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true))
c.stream = stream c.stream = stream
if err != nil { if err != nil {
return nil, err return nil, err